From 277059da696dd3de3de65600407e7c5c0be70855 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=A2=A6=E5=A2=83=E8=BF=B7=E7=A6=BB?= Date: Sat, 24 Jul 2021 17:37:52 +0800 Subject: [PATCH] fix bug, refactor (#81) * 1.code refactor 2.Fixed constructor parameter mismatch when the first code block of coritization has multiple parameters --- .../scala/io/github/dreamylost/apply.scala | 2 +- .../scala/io/github/dreamylost/builder.scala | 2 +- .../io/github/dreamylost/constructor.scala | 2 +- .../github/dreamylost/equalsAndHashCode.scala | 2 +- .../scala/io/github/dreamylost/json.scala | 2 +- src/main/scala/io/github/dreamylost/log.scala | 2 +- .../io/github/dreamylost/logs/LogType.scala | 6 +- ...mon.scala => AbstractMacroProcessor.scala} | 123 +++------- .../github/dreamylost/macros/applyMacro.scala | 63 ++--- .../dreamylost/macros/builderMacro.scala | 108 ++++----- .../dreamylost/macros/constructorMacro.scala | 152 ++++++------ .../macros/equalsAndHashCodeMacro.scala | 176 +++++++------- .../github/dreamylost/macros/jsonMacro.scala | 71 +++--- .../github/dreamylost/macros/logMacro.scala | 93 ++++---- .../dreamylost/macros/synchronizedMacro.scala | 49 ++-- .../dreamylost/macros/toStringMacro.scala | 222 +++++++++--------- .../io/github/dreamylost/synchronized.scala | 2 +- .../scala/io/github/dreamylost/toString.scala | 2 +- .../io/github/dreamylost/ApplyTest.scala | 14 +- .../io/github/dreamylost/BuilderTest.scala | 1 + .../dreamylost/EqualsAndHashCodeTest.scala | 32 ++- .../scala/io/github/dreamylost/LogTest.scala | 19 +- .../io/github/dreamylost/OthersTest.scala | 42 ++++ 23 files changed, 624 insertions(+), 563 deletions(-) rename src/main/scala/io/github/dreamylost/macros/{MacroCommon.scala => AbstractMacroProcessor.scala} (68%) create mode 100644 src/test/scala/io/github/dreamylost/OthersTest.scala diff --git a/src/main/scala/io/github/dreamylost/apply.scala b/src/main/scala/io/github/dreamylost/apply.scala index c8ae04c..10c725c 100644 --- a/src/main/scala/io/github/dreamylost/apply.scala +++ b/src/main/scala/io/github/dreamylost/apply.scala @@ -37,5 +37,5 @@ import scala.annotation.{ StaticAnnotation, compileTimeOnly } final class apply( verbose: Boolean = false ) extends StaticAnnotation { - def macroTransform(annottees: Any*): Any = macro applyMacro.impl + def macroTransform(annottees: Any*): Any = macro applyMacro.applyProcessor.impl } diff --git a/src/main/scala/io/github/dreamylost/builder.scala b/src/main/scala/io/github/dreamylost/builder.scala index 023ed8d..5fff8f0 100644 --- a/src/main/scala/io/github/dreamylost/builder.scala +++ b/src/main/scala/io/github/dreamylost/builder.scala @@ -34,5 +34,5 @@ import scala.annotation.{ StaticAnnotation, compileTimeOnly } */ @compileTimeOnly("enable macro to expand macro annotations") final class builder extends StaticAnnotation { - def macroTransform(annottees: Any*): Any = macro builderMacro.impl + def macroTransform(annottees: Any*): Any = macro builderMacro.BuilderProcessor.impl } diff --git a/src/main/scala/io/github/dreamylost/constructor.scala b/src/main/scala/io/github/dreamylost/constructor.scala index 43dfd68..e2fe9e9 100644 --- a/src/main/scala/io/github/dreamylost/constructor.scala +++ b/src/main/scala/io/github/dreamylost/constructor.scala @@ -39,5 +39,5 @@ final class constructor( verbose: Boolean = false, excludeFields: Seq[String] = Nil ) extends StaticAnnotation { - def macroTransform(annottees: Any*): Any = macro constructorMacro.impl + def macroTransform(annottees: Any*): Any = macro constructorMacro.ConstructorProcessor.impl } diff --git a/src/main/scala/io/github/dreamylost/equalsAndHashCode.scala b/src/main/scala/io/github/dreamylost/equalsAndHashCode.scala index deaf795..578fc81 100644 --- a/src/main/scala/io/github/dreamylost/equalsAndHashCode.scala +++ b/src/main/scala/io/github/dreamylost/equalsAndHashCode.scala @@ -40,5 +40,5 @@ final class equalsAndHashCode( excludeFields: Seq[String] = Nil ) extends StaticAnnotation { - def macroTransform(annottees: Any*): Any = macro equalsAndHashCodeMacro.impl + def macroTransform(annottees: Any*): Any = macro equalsAndHashCodeMacro.EqualsAndHashCodeProcessor.impl } diff --git a/src/main/scala/io/github/dreamylost/json.scala b/src/main/scala/io/github/dreamylost/json.scala index b6f58bd..0e52537 100644 --- a/src/main/scala/io/github/dreamylost/json.scala +++ b/src/main/scala/io/github/dreamylost/json.scala @@ -34,5 +34,5 @@ import scala.annotation.{ StaticAnnotation, compileTimeOnly } */ @compileTimeOnly("enable macro to expand macro annotations") final class json extends StaticAnnotation { - def macroTransform(annottees: Any*): Any = macro jsonMacro.impl + def macroTransform(annottees: Any*): Any = macro jsonMacro.JsonProcessor.impl } diff --git a/src/main/scala/io/github/dreamylost/log.scala b/src/main/scala/io/github/dreamylost/log.scala index d8d2f81..07ea4d5 100644 --- a/src/main/scala/io/github/dreamylost/log.scala +++ b/src/main/scala/io/github/dreamylost/log.scala @@ -40,5 +40,5 @@ final class log( verbose: Boolean = false, logType: LogType.LogType = LogType.JLog ) extends StaticAnnotation { - def macroTransform(annottees: Any*): Any = macro logMacro.impl + def macroTransform(annottees: Any*): Any = macro logMacro.LogProcessor.impl } diff --git a/src/main/scala/io/github/dreamylost/logs/LogType.scala b/src/main/scala/io/github/dreamylost/logs/LogType.scala index 5a956f6..218023f 100644 --- a/src/main/scala/io/github/dreamylost/logs/LogType.scala +++ b/src/main/scala/io/github/dreamylost/logs/LogType.scala @@ -39,9 +39,11 @@ object LogType extends Enumeration { } def getLogType(shortType: String): LogType = { - val tpe = PACKAGE + "." + shortType + val tpe1 = s"$PACKAGE.logs.$shortType" //LogType.JLog + val tpe2 = s"$PACKAGE.logs.LogType.$shortType" // JLog val v = LogType.values.find(p => { - s"$PACKAGE.${p.toString}" == tpe || s"$PACKAGE.$LogType.${p.toString}" == tpe + s"$PACKAGE.logs.LogType.${p.toString}" == tpe1 || + s"$PACKAGE.logs.LogType.${p.toString}" == tpe2 || s"$PACKAGE.logs.LogType.${p.toString}" == shortType }).getOrElse(throw new Exception(s"Not support log type: $shortType")).toString LogType.withName(v) } diff --git a/src/main/scala/io/github/dreamylost/macros/MacroCommon.scala b/src/main/scala/io/github/dreamylost/macros/AbstractMacroProcessor.scala similarity index 68% rename from src/main/scala/io/github/dreamylost/macros/MacroCommon.scala rename to src/main/scala/io/github/dreamylost/macros/AbstractMacroProcessor.scala index 9301593..74f208e 100644 --- a/src/main/scala/io/github/dreamylost/macros/MacroCommon.scala +++ b/src/main/scala/io/github/dreamylost/macros/AbstractMacroProcessor.scala @@ -24,48 +24,44 @@ package io.github.dreamylost.macros import scala.reflect.macros.whitebox /** - * Common methods * * @author 梦境迷离 - * @since 2021/6/28 + * @since 2021/7/24 * @version 1.0 */ -trait MacroCommon { +abstract class AbstractMacroProcessor(val c: whitebox.Context) { + import c.universe._ + + def impl(annottees: Expr[Any]*): Expr[Any] /** * Eval tree. * - * @param c * @param tree * @tparam T * @return */ - def evalTree[T: c.WeakTypeTag](c: whitebox.Context)(tree: c.Tree): T = c.eval(c.Expr[T](c.untypecheck(tree.duplicate))) - - def extractArgumentsTuple1[T: c.WeakTypeTag](c: whitebox.Context)(partialFunction: PartialFunction[c.Tree, Tuple1[T]]): Tuple1[T] = { - partialFunction.apply(c.prefix.tree) - } + def evalTree[T: WeakTypeTag](tree: Tree): T = c.eval(c.Expr[T](c.untypecheck(tree.duplicate))) - def extractArgumentsTuple2[T1: c.WeakTypeTag, T2: c.WeakTypeTag](c: whitebox.Context)(partialFunction: PartialFunction[c.Tree, (T1, T2)]): (T1, T2) = { + def extractArgumentsTuple1[T: WeakTypeTag](partialFunction: PartialFunction[Tree, Tuple1[T]]): Tuple1[T] = { partialFunction.apply(c.prefix.tree) } - def extractArgumentsTuple3[T1: c.WeakTypeTag, T2: c.WeakTypeTag, T3: c.WeakTypeTag](c: whitebox.Context)(partialFunction: PartialFunction[c.Tree, (T1, T2, T3)]): (T1, T2, T3) = { + def extractArgumentsTuple2[T1: WeakTypeTag, T2: WeakTypeTag](partialFunction: PartialFunction[Tree, (T1, T2)]): (T1, T2) = { partialFunction.apply(c.prefix.tree) } - def extractArgumentsTuple4[T1: c.WeakTypeTag, T2: c.WeakTypeTag, T3: c.WeakTypeTag, T4: c.WeakTypeTag](c: whitebox.Context)(partialFunction: PartialFunction[c.Tree, (T1, T2, T3, T4)]): (T1, T2, T3, T4) = { + def extractArgumentsTuple4[T1: WeakTypeTag, T2: WeakTypeTag, T3: WeakTypeTag, T4: WeakTypeTag](partialFunction: PartialFunction[Tree, (T1, T2, T3, T4)]): (T1, T2, T3, T4) = { partialFunction.apply(c.prefix.tree) } /** * Output ast result. * - * @param c * @param force * @param resTree */ - def printTree(c: whitebox.Context)(force: Boolean, resTree: c.Tree): Unit = { + def printTree(force: Boolean, resTree: Tree): Unit = { c.info( c.enclosingPosition, "\n###### Expanded macro ######\n" + resTree.toString() + "\n###### Expanded macro ######\n", @@ -76,12 +72,10 @@ trait MacroCommon { /** * Check the class and its companion object, and return the class definition. * - * @param c * @param annottees * @return Return ClassDef */ - def checkAndGetClassDef(c: whitebox.Context)(annottees: c.Expr[Any]*): c.universe.ClassDef = { - import c.universe._ + def checkAndGetClassDef(annottees: Expr[Any]*): ClassDef = { annottees.map(_.tree).toList match { case (classDecl: ClassDef) :: Nil => classDecl case (classDecl: ClassDef) :: (compDecl: ModuleDef) :: Nil => classDecl @@ -89,31 +83,13 @@ trait MacroCommon { } } - /** - * Get class if it exists. - * - * @param c - * @param annottees - * @return Return ClassDef without verify. - */ - def tryGetClassDef(c: whitebox.Context)(annottees: c.Expr[Any]*): Option[c.universe.ClassDef] = { - import c.universe._ - annottees.map(_.tree).toList match { - case (classDecl: ClassDef) :: Nil => Some(classDecl) - case (classDecl: ClassDef) :: (compDecl: ModuleDef) :: Nil => Some(classDecl) - case _ => None - } - } - /** * Get companion object if it exists. * - * @param c * @param annottees * @return */ - def tryGetCompanionObject(c: whitebox.Context)(annottees: c.Expr[Any]*): Option[c.universe.ModuleDef] = { - import c.universe._ + def tryGetCompanionObject(annottees: Expr[Any]*): Option[ModuleDef] = { annottees.map(_.tree).toList match { case (classDecl: ClassDef) :: Nil => None case (classDecl: ClassDef) :: (compDecl: ModuleDef) :: Nil => Some(compDecl) @@ -125,14 +101,12 @@ trait MacroCommon { /** * Wrap tree result with companion object. * - * @param c * @param resTree class * @param annottees * @return */ - def treeResultWithCompanionObject(c: whitebox.Context)(resTree: c.Tree, annottees: c.Expr[Any]*): c.universe.Tree = { - import c.universe._ - val companionOpt = tryGetCompanionObject(c)(annottees: _*) + def treeResultWithCompanionObject(resTree: Tree, annottees: Expr[Any]*): Tree = { + val companionOpt = tryGetCompanionObject(annottees: _*) if (companionOpt.isEmpty) { resTree } else { @@ -146,17 +120,15 @@ trait MacroCommon { /** * Modify the associated object itself according to whether there is an associated object. * - * @param c * @param annottees * @param modifyAction The actual processing function * @return Return the result of modifyAction */ - def handleWithImplType(c: whitebox.Context)(annottees: c.Expr[Any]*) - (modifyAction: (c.universe.ClassDef, Option[c.universe.ModuleDef]) => Any): c.Expr[Nothing] = { - import c.universe._ + def handleWithImplType(annottees: Expr[Any]*) + (modifyAction: (ClassDef, Option[ModuleDef]) => Any): Expr[Nothing] = { annottees.map(_.tree) match { - case (classDecl: ClassDef) :: Nil => modifyAction(classDecl, None).asInstanceOf[c.Expr[Nothing]] - case (classDecl: ClassDef) :: (compDecl: ModuleDef) :: Nil => modifyAction(classDecl, Some(compDecl)).asInstanceOf[c.Expr[Nothing]] + case (classDecl: ClassDef) :: Nil => modifyAction(classDecl, None).asInstanceOf[Expr[Nothing]] + case (classDecl: ClassDef) :: (compDecl: ModuleDef) :: Nil => modifyAction(classDecl, Some(compDecl)).asInstanceOf[Expr[Nothing]] case _ => c.abort(c.enclosingPosition, ErrorMessage.UNEXPECTED_PATTERN) } } @@ -164,12 +136,10 @@ trait MacroCommon { /** * Expand the class and check whether the class is a case class. * - * @param c * @param annotateeClass classDef * @return Return true if it is a case class */ - def isCaseClass(c: whitebox.Context)(annotateeClass: c.universe.ClassDef): Boolean = { - import c.universe._ + def isCaseClass(annotateeClass: ClassDef): Boolean = { annotateeClass match { case q"$mods class $tpname[..$tparams] $ctorMods(...$paramss) extends { ..$earlydefns } with ..$parents { $self => ..$stats }" => if (mods.asInstanceOf[Modifiers].hasFlag(Flag.CASE)) { @@ -183,12 +153,10 @@ trait MacroCommon { /** * Expand the constructor and get the field TermName. * - * @param c * @param field * @return */ - def getFieldTermName(c: whitebox.Context)(field: c.universe.Tree): c.universe.TermName = { - import c.universe._ + def getFieldTermName(field: Tree): TermName = { field match { case q"$mods val $tname: $tpt = $expr" => tname.asInstanceOf[TermName] case q"$mods var $tname: $tpt = $expr" => tname.asInstanceOf[TermName] @@ -200,12 +168,10 @@ trait MacroCommon { /** * Expand the method params and get the param Name. * - * @param c * @param field * @return */ - def getMethodParamName(c: whitebox.Context)(field: c.universe.Tree): c.universe.Name = { - import c.universe._ + def getMethodParamName(field: Tree): Name = { field match { case q"$mods val $tname: $tpt = $expr" => tpt.asInstanceOf[Ident].name.decodedName } @@ -214,12 +180,10 @@ trait MacroCommon { /** * Check whether the mods of the fields has a `private[this]`, because it cannot be used in equals method. * - * @param c * @param field * @return */ - def classParamsIsPrivate(c: whitebox.Context)(field: c.universe.Tree): Boolean = { - import c.universe._ + def classParamsIsPrivate(field: Tree): Boolean = { field match { case q"$mods val $tname: $tpt = $expr" => if (mods.asInstanceOf[Modifiers].hasFlag(Flag.PRIVATE)) false else true case q"$mods var $tname: $tpt = $expr" => true @@ -229,12 +193,10 @@ trait MacroCommon { /** * Expand the constructor and get the field with assign. * - * @param c * @param annotteeClassParams * @return */ - def getFieldAssignExprs(c: whitebox.Context)(annotteeClassParams: Seq[c.Tree]): Seq[c.Tree] = { - import c.universe._ + def getFieldAssignExprs(annotteeClassParams: Seq[Tree]): Seq[Tree] = { annotteeClassParams.map { case q"$mods var $tname: $tpt = $expr" => q"$tname: $tpt" //Ignore expr case q"$mods val $tname: $tpt = $expr" => q"$tname: $tpt" @@ -244,16 +206,14 @@ trait MacroCommon { /** * Modify companion objects. * - * @param c * @param compDeclOpt * @param codeBlock * @param className * @return */ - def modifiedCompanion(c: whitebox.Context)( - compDeclOpt: Option[c.universe.ModuleDef], - codeBlock: c.Tree, className: c.TypeName): c.universe.Tree = { - import c.universe._ + def modifiedCompanion( + compDeclOpt: Option[ModuleDef], + codeBlock: Tree, className: TypeName): Tree = { compDeclOpt map { compDecl => val q"$mods object $obj extends ..$bases { ..$body }" = compDecl val o = @@ -276,11 +236,9 @@ trait MacroCommon { /** * Extract the internal fields of members belonging to the class, but not in primary constructor. * - * @param c * @param annotteeClassDefinitions */ - def getClassMemberValDefs(c: whitebox.Context)(annotteeClassDefinitions: Seq[c.Tree]): Seq[c.Tree] = { - import c.universe._ + def getClassMemberValDefs(annotteeClassDefinitions: Seq[Tree]): Seq[Tree] = { annotteeClassDefinitions.filter(p => p match { case _: ValDef => true case _ => false @@ -290,11 +248,9 @@ trait MacroCommon { /** * Extract the methods belonging to the class, contains Secondary Constructor. * - * @param c * @param annotteeClassDefinitions */ - def getClassMemberDefDefs(c: whitebox.Context)(annotteeClassDefinitions: Seq[c.Tree]): Seq[c.Tree] = { - import c.universe._ + def getClassMemberDefDefs(annotteeClassDefinitions: Seq[Tree]): Seq[Tree] = { annotteeClassDefinitions.filter(p => p match { case _: DefDef => true case _ => false @@ -310,9 +266,8 @@ trait MacroCommon { * @return A constructor with currying, it not contains tpt, provide for calling method. * @example [[new TestClass12(i)(j)(k)(t)]] */ - def getConstructorWithCurrying(c: whitebox.Context)(typeName: c.TypeName, fieldss: List[List[c.Tree]], isCase: Boolean): c.Tree = { - import c.universe._ - val allFieldsTermName = fieldss.map(f => f.map(ff => getFieldTermName(c)(ff))) + def getConstructorWithCurrying(typeName: TypeName, fieldss: List[List[Tree]], isCase: Boolean): Tree = { + val allFieldsTermName = fieldss.map(f => f.map(ff => getFieldTermName(ff))) // not currying val constructor = if (fieldss.isEmpty || fieldss.size == 1) { q"${if (isCase) q"${typeName.toTermName}(..${allFieldsTermName.flatten})" else q"new $typeName(..${allFieldsTermName.flatten})"}" @@ -320,7 +275,7 @@ trait MacroCommon { // currying val first = allFieldsTermName.head if (isCase) q"${typeName.toTermName}(...$first)(...${allFieldsTermName.tail})" - else q"new $typeName(...$first)(...${allFieldsTermName.tail})" + else q"new $typeName(..$first)(...${allFieldsTermName.tail})" } c.info(c.enclosingPosition, s"getConstructorWithCurrying constructor: $constructor, paramss: $fieldss", force = true) constructor @@ -334,17 +289,16 @@ trait MacroCommon { * @return A apply method with currying. * @example [[def apply(int: Int)(j: Int)(k: Option[String])(t: Option[Long]): B3 = new B3(int)(j)(k)(t)]] */ - def getApplyMethodWithCurrying(c: whitebox.Context)(typeName: c.TypeName, fieldss: List[List[c.Tree]], classTypeParams: List[c.Tree]): c.Tree = { - import c.universe._ - val allFieldsTermName = fieldss.map(f => getFieldAssignExprs(c)(f)) - val returnTypeParams = extractClassTypeParamsTypeName(c)(classTypeParams) + def getApplyMethodWithCurrying(typeName: TypeName, fieldss: List[List[Tree]], classTypeParams: List[Tree]): Tree = { + val allFieldsTermName = fieldss.map(f => getFieldAssignExprs(f)) + val returnTypeParams = extractClassTypeParamsTypeName(classTypeParams) // not currying val applyMethod = if (fieldss.isEmpty || fieldss.size == 1) { - q"def apply[..$classTypeParams](..${allFieldsTermName.flatten}): $typeName[..$returnTypeParams] = ${getConstructorWithCurrying(c)(typeName, fieldss, isCase = false)}" + q"def apply[..$classTypeParams](..${allFieldsTermName.flatten}): $typeName[..$returnTypeParams] = ${getConstructorWithCurrying(typeName, fieldss, isCase = false)}" } else { // currying val first = allFieldsTermName.head - q"def apply[..$classTypeParams](..$first)(...${allFieldsTermName.tail}): $typeName[..$returnTypeParams] = ${getConstructorWithCurrying(c)(typeName, fieldss, isCase = false)}" + q"def apply[..$classTypeParams](..$first)(...${allFieldsTermName.tail}): $typeName[..$returnTypeParams] = ${getConstructorWithCurrying(typeName, fieldss, isCase = false)}" } c.info(c.enclosingPosition, s"getApplyWithCurrying constructor: $applyMethod, paramss: $fieldss", force = true) applyMethod @@ -375,14 +329,13 @@ trait MacroCommon { * Gets a list of generic parameters. * This is because the generic parameters of a class cannot be used directly in the return type, and need to be converted. * - * @param c * @param tpParams * @return */ - def extractClassTypeParamsTypeName(c: whitebox.Context)(tpParams: List[c.Tree]): List[c.TypeName] = { - import c.universe._ + def extractClassTypeParamsTypeName(tpParams: List[Tree]): List[TypeName] = { tpParams.map { case t: TypeDef => t.name } } + } diff --git a/src/main/scala/io/github/dreamylost/macros/applyMacro.scala b/src/main/scala/io/github/dreamylost/macros/applyMacro.scala index 6ce393c..eff6674 100644 --- a/src/main/scala/io/github/dreamylost/macros/applyMacro.scala +++ b/src/main/scala/io/github/dreamylost/macros/applyMacro.scala @@ -29,45 +29,46 @@ import scala.reflect.macros.whitebox * @since 2021/7/7 * @version 1.0 */ -object applyMacro extends MacroCommon { - def impl(c: whitebox.Context)(annottees: c.Expr[Any]*): c.Expr[Any] = { - import c.universe._ - val args: Tuple1[Boolean] = extractArgumentsTuple1(c) { - case q"new apply(verbose=$verbose)" => Tuple1(evalTree(c)(verbose.asInstanceOf[Tree])) - case q"new apply()" => Tuple1(false) - case _ => c.abort(c.enclosingPosition, ErrorMessage.UNEXPECTED_PATTERN) - } - - c.info(c.enclosingPosition, s"annottees: $annottees, args: $args", force = args._1) +object applyMacro { - val annotateeClass: ClassDef = checkAndGetClassDef(c)(annottees: _*) - val isCase: Boolean = isCaseClass(c)(annotateeClass) - c.info(c.enclosingPosition, s"impl argument: $args, isCase: $isCase", force = args._1) + class applyProcessor(override val c: whitebox.Context) extends AbstractMacroProcessor(c) { - if (isCase) c.abort(c.enclosingPosition, s"Annotation is only supported on 'case class'") + import c.universe._ - def modifiedDeclaration(classDecl: ClassDef, compDeclOpt: Option[ModuleDef] = None): Any = { - val (className, classParams, classTypeParams) = classDecl match { - case q"$mods class $tpname[..$tparams] $ctorMods(...$paramss) extends ..$bases { ..$body }" => - c.info(c.enclosingPosition, s"modifiedDeclaration className: $tpname, paramss: $paramss", force = args._1) - (tpname, paramss.asInstanceOf[List[List[Tree]]], tparams.asInstanceOf[List[Tree]]) - case _ => c.abort(c.enclosingPosition, s"${ErrorMessage.ONLY_CLASS} classDef: $classDecl") + override def impl(annottees: Expr[Any]*): Expr[Any] = { + val args: Tuple1[Boolean] = extractArgumentsTuple1 { + case q"new apply(verbose=$verbose)" => Tuple1(evalTree(verbose.asInstanceOf[Tree])) + case q"new apply()" => Tuple1(false) + case _ => c.abort(c.enclosingPosition, ErrorMessage.UNEXPECTED_PATTERN) } - c.info(c.enclosingPosition, s"modifiedDeclaration compDeclOpt: $compDeclOpt, annotteeClassParams: $classParams", force = args._1) - val tpName = className.asInstanceOf[TypeName] - val apply = getApplyMethodWithCurrying(c)(tpName, classParams, classTypeParams) - val compDecl = modifiedCompanion(c)(compDeclOpt, apply, tpName) - c.Expr( - q""" + val annotateeClass: ClassDef = checkAndGetClassDef(annottees: _*) + val isCase: Boolean = isCaseClass(annotateeClass) + c.info(c.enclosingPosition, s"impl argument: $args, isCase: $isCase", force = args._1) + + if (isCase) c.abort(c.enclosingPosition, s"Annotation is only supported on 'case class'") + + def modifiedDeclaration(classDecl: ClassDef, compDeclOpt: Option[ModuleDef] = None): Any = { + val (className, classParams, classTypeParams) = classDecl match { + case q"$mods class $tpname[..$tparams] $ctorMods(...$paramss) extends ..$bases { ..$body }" => + c.info(c.enclosingPosition, s"modifiedDeclaration className: $tpname, paramss: $paramss", force = args._1) + (tpname, paramss.asInstanceOf[List[List[Tree]]], tparams.asInstanceOf[List[Tree]]) + case _ => c.abort(c.enclosingPosition, s"${ErrorMessage.ONLY_CLASS} classDef: $classDecl") + } + c.info(c.enclosingPosition, s"modifiedDeclaration compDeclOpt: $compDeclOpt, annotteeClassParams: $classParams", force = args._1) + val tpName = className.asInstanceOf[TypeName] + val apply = getApplyMethodWithCurrying(tpName, classParams, classTypeParams) + val compDecl = modifiedCompanion(compDeclOpt, apply, tpName) + c.Expr( + q""" $classDecl $compDecl """) - } - - val resTree = handleWithImplType(c)(annottees: _*)(modifiedDeclaration) - printTree(c)(force = args._1, resTree.tree) + } - resTree + val resTree = handleWithImplType(annottees: _*)(modifiedDeclaration) + printTree(force = args._1, resTree.tree) + resTree + } } } diff --git a/src/main/scala/io/github/dreamylost/macros/builderMacro.scala b/src/main/scala/io/github/dreamylost/macros/builderMacro.scala index 83c4108..7500247 100644 --- a/src/main/scala/io/github/dreamylost/macros/builderMacro.scala +++ b/src/main/scala/io/github/dreamylost/macros/builderMacro.scala @@ -29,51 +29,54 @@ import scala.reflect.macros.whitebox * @since 2021/7/7 * @version 1.0 */ -object builderMacro extends MacroCommon { +object builderMacro { + private final val BUFFER_CLASS_NAME_SUFFIX = "Builder" - def impl(c: whitebox.Context)(annottees: c.Expr[Any]*): c.Expr[Any] = { + class BuilderProcessor(override val c: whitebox.Context) extends AbstractMacroProcessor(c) { + import c.universe._ - def getBuilderClassName(classTree: TypeName): TypeName = { - TypeName(classTree.toTermName.decodedName.toString + BUFFER_CLASS_NAME_SUFFIX) - } + override def impl(annottees: c.universe.Expr[Any]*): c.universe.Expr[Any] = { + def getBuilderClassName(classTree: TypeName): TypeName = { + TypeName(classTree.toTermName.decodedName.toString + BUFFER_CLASS_NAME_SUFFIX) + } - def fieldSetMethod(typeName: TypeName, field: Tree, classTypeParams: List[Tree]): c.Tree = { - val builderClassName = getBuilderClassName(typeName) - val returnTypeParams = extractClassTypeParamsTypeName(c)(classTypeParams) - field match { - case q"$mods var $tname: $tpt = $expr" => - q""" + def fieldSetMethod(typeName: TypeName, field: Tree, classTypeParams: List[Tree]): Tree = { + val builderClassName = getBuilderClassName(typeName) + val returnTypeParams = extractClassTypeParamsTypeName(classTypeParams) + field match { + case q"$mods var $tname: $tpt = $expr" => + q""" def $tname($tname: $tpt): $builderClassName[..$returnTypeParams] = { this.$tname = $tname this } """ - case q"$mods val $tname: $tpt = $expr" => - q""" + case q"$mods val $tname: $tpt = $expr" => + q""" def $tname($tname: $tpt): $builderClassName[..$returnTypeParams] = { this.$tname = $tname this } """ + } } - } - def fieldDefinition(field: Tree): Tree = { - field match { - case q"$mods val $tname: $tpt = $expr" => q"""private var $tname: $tpt = $expr""" - case q"$mods var $tname: $tpt = $expr" => q"""private var $tname: $tpt = $expr""" + def fieldDefinition(field: Tree): Tree = { + field match { + case q"$mods val $tname: $tpt = $expr" => q"""private var $tname: $tpt = $expr""" + case q"$mods var $tname: $tpt = $expr" => q"""private var $tname: $tpt = $expr""" + } } - } - def builderTemplate(typeName: TypeName, fieldss: List[List[Tree]], classTypeParams: List[Tree], isCase: Boolean): Tree = { - val fields = fieldss.flatten - val builderClassName = getBuilderClassName(typeName) - val builderFieldMethods = fields.map(f => fieldSetMethod(typeName, f, classTypeParams)) - val builderFieldDefinitions = fields.map(f => fieldDefinition(f)) - val returnTypeParams = extractClassTypeParamsTypeName(c)(classTypeParams) - q""" + def builderTemplate(typeName: TypeName, fieldss: List[List[Tree]], classTypeParams: List[Tree], isCase: Boolean): Tree = { + val fields = fieldss.flatten + val builderClassName = getBuilderClassName(typeName) + val builderFieldMethods = fields.map(f => fieldSetMethod(typeName, f, classTypeParams)) + val builderFieldDefinitions = fields.map(f => fieldDefinition(f)) + val returnTypeParams = extractClassTypeParamsTypeName(classTypeParams) + q""" def builder[..$classTypeParams](): $builderClassName[..$returnTypeParams] = new $builderClassName() class $builderClassName[..$classTypeParams] { @@ -82,40 +85,39 @@ object builderMacro extends MacroCommon { ..$builderFieldMethods - def build(): $typeName[..$returnTypeParams] = ${getConstructorWithCurrying(c)(typeName, fieldss, isCase)} + def build(): $typeName[..$returnTypeParams] = ${getConstructorWithCurrying(typeName, fieldss, isCase)} } """ - } - - // Why use Any? The dependent type need aux-pattern in scala2. Now let's get around this. - def modifiedDeclaration(classDecl: ClassDef, compDeclOpt: Option[ModuleDef] = None): Any = { - val (className, fieldss, classTypeParams) = classDecl match { - // @see https://scala-lang.org/files/archive/spec/2.13/05-classes-and-objects.html - case q"$mods class $tpname[..$tparams](...$paramss) extends ..$bases { ..$body }" => - c.info(c.enclosingPosition, s"modifiedDeclaration className: $tpname, paramss: $paramss", force = true) - (tpname, paramss.asInstanceOf[List[List[Tree]]], tparams.asInstanceOf[List[Tree]]) - case _ => c.abort(c.enclosingPosition, s"${ErrorMessage.ONLY_CLASS} classDef: $classDecl") } - c.info(c.enclosingPosition, s"modifiedDeclaration compDeclOpt: $compDeclOpt, fieldss: $fieldss", force = true) - - val cName = className.asInstanceOf[TypeName] - val isCase = isCaseClass(c)(classDecl) - val builder = builderTemplate(cName, fieldss, classTypeParams, isCase) - val compDecl = modifiedCompanion(c)(compDeclOpt, builder, cName) - c.info(c.enclosingPosition, s"builderTree: $builder, compDecl: $compDecl", force = true) - // Return both the class and companion object declarations - c.Expr( - q""" + + // Why use Any? The dependent type need aux-pattern in scala2. Now let's get around this. + def modifiedDeclaration(classDecl: ClassDef, compDeclOpt: Option[ModuleDef] = None): Any = { + val (className, fieldss, classTypeParams) = classDecl match { + // @see https://scala-lang.org/files/archive/spec/2.13/05-classes-and-objects.html + case q"$mods class $tpname[..$tparams](...$paramss) extends ..$bases { ..$body }" => + c.info(c.enclosingPosition, s"modifiedDeclaration className: $tpname, paramss: $paramss", force = true) + (tpname, paramss.asInstanceOf[List[List[Tree]]], tparams.asInstanceOf[List[Tree]]) + case _ => c.abort(c.enclosingPosition, s"${ErrorMessage.ONLY_CLASS} classDef: $classDecl") + } + + val cName = className.asInstanceOf[TypeName] + val isCase = isCaseClass(classDecl) + val builder = builderTemplate(cName, fieldss, classTypeParams, isCase) + val compDecl = modifiedCompanion(compDeclOpt, builder, cName) + c.info(c.enclosingPosition, s"builderTree: $builder, compDecl: $compDecl", force = true) + // Return both the class and companion object declarations + c.Expr( + q""" $classDecl $compDecl """) - } - - c.info(c.enclosingPosition, s"builder annottees: $annottees", force = true) + } - val resTree = handleWithImplType(c)(annottees: _*)(modifiedDeclaration) - printTree(c)(force = true, resTree.tree) + val resTree = handleWithImplType(annottees: _*)(modifiedDeclaration) + printTree(force = true, resTree.tree) - resTree + resTree + } } + } diff --git a/src/main/scala/io/github/dreamylost/macros/constructorMacro.scala b/src/main/scala/io/github/dreamylost/macros/constructorMacro.scala index 76c0caa..270021b 100644 --- a/src/main/scala/io/github/dreamylost/macros/constructorMacro.scala +++ b/src/main/scala/io/github/dreamylost/macros/constructorMacro.scala @@ -29,109 +29,111 @@ import scala.reflect.macros.whitebox * @since 2021/7/7 * @version 1.0 */ -object constructorMacro extends MacroCommon { - def impl(c: whitebox.Context)(annottees: c.Expr[Any]*): c.Expr[Any] = { - import c.universe._ - val args: (Boolean, Seq[String]) = extractArgumentsTuple2(c) { - case q"new constructor(verbose=$verbose)" => (evalTree(c)(verbose.asInstanceOf[Tree]), Nil) - case q"new constructor(excludeFields=$excludeFields)" => (false, evalTree(c)(excludeFields.asInstanceOf[Tree])) - case q"new constructor(verbose=$verbose, excludeFields=$excludeFields)" => (evalTree(c)(verbose.asInstanceOf[Tree]), evalTree(c)(excludeFields.asInstanceOf[Tree])) - case q"new constructor()" => (false, Nil) - case _ => c.abort(c.enclosingPosition, ErrorMessage.UNEXPECTED_PATTERN) - } +object constructorMacro { - val annotateeClass: ClassDef = checkAndGetClassDef(c)(annottees: _*) - val isCase: Boolean = isCaseClass(c)(annotateeClass) - if (isCase) { - c.abort(c.enclosingPosition, s"${ErrorMessage.ONLY_CLASS} classDef: $annotateeClass") - } + class ConstructorProcessor(override val c: whitebox.Context) extends AbstractMacroProcessor(c) { - c.info(c.enclosingPosition, s"annottees: $annottees, annotateeClass: $annotateeClass", args._1) + import c.universe._ - def modifiedDeclaration(classDecl: ClassDef, compDeclOpt: Option[ModuleDef] = None): Any = { - val (annotteeClassParams, annotteeClassDefinitions) = classDecl match { - case q"$mods class $tpname[..$tparams] $ctorMods(...$paramss) extends { ..$earlydefns } with ..$parents { $self => ..$stats }" => - c.info(c.enclosingPosition, s"modifiedDeclaration className: $tpname, paramss: $paramss", force = args._1) - (paramss.asInstanceOf[List[List[Tree]]], stats.asInstanceOf[Seq[Tree]]) - case _ => c.abort(c.enclosingPosition, s"${ErrorMessage.ONLY_CLASS} classDef: $classDecl") + override def impl(annottees: c.universe.Expr[Any]*): c.universe.Expr[Any] = { + val args: (Boolean, Seq[String]) = extractArgumentsTuple2 { + case q"new constructor(verbose=$verbose)" => (evalTree(verbose.asInstanceOf[Tree]), Nil) + case q"new constructor(excludeFields=$excludeFields)" => (false, evalTree(excludeFields.asInstanceOf[Tree])) + case q"new constructor(verbose=$verbose, excludeFields=$excludeFields)" => (evalTree(verbose.asInstanceOf[Tree]), evalTree(excludeFields.asInstanceOf[Tree])) + case q"new constructor()" => (false, Nil) + case _ => c.abort(c.enclosingPosition, ErrorMessage.UNEXPECTED_PATTERN) } - // Extract the internal fields of members belonging to the class, but not in primary constructor. - val classFieldDefinitions = getClassMemberValDefs(c)(annotteeClassDefinitions) - val excludeFields = args._2 + val annotateeClass: ClassDef = checkAndGetClassDef(annottees: _*) + val isCase: Boolean = isCaseClass(annotateeClass) + if (isCase) { + c.abort(c.enclosingPosition, s"${ErrorMessage.ONLY_CLASS} classDef: $annotateeClass") + } - /** - * Extract the internal fields of members belonging to the class, but not in primary constructor and only `var`. - */ - def getClassMemberVarDefOnlyAssignExpr: Seq[c.Tree] = { - import c.universe._ - getClassMemberValDefs(c)(annotteeClassDefinitions).filter(_ match { - case q"$mods var $tname: $tpt = $expr" if !excludeFields.contains(tname.asInstanceOf[TermName].decodedName.toString) => true - case _ => false - }).map { - case q"$mods var $pat = $expr" => - // TODO getClass RETURN a java type, maybe we can try use class reflect to get the fields type name. - q"$pat: ${TypeName(toScalaType(evalTree(c)(expr.asInstanceOf[Tree]).getClass.getTypeName))}" - case q"$mods var $tname: $tpt = $expr" => q"$tname: $tpt" + def modifiedDeclaration(classDecl: ClassDef, compDeclOpt: Option[ModuleDef] = None): Any = { + val (annotteeClassParams, annotteeClassDefinitions) = classDecl match { + case q"$mods class $tpname[..$tparams] $ctorMods(...$paramss) extends { ..$earlydefns } with ..$parents { $self => ..$stats }" => + c.info(c.enclosingPosition, s"modifiedDeclaration className: $tpname, paramss: $paramss", force = args._1) + (paramss.asInstanceOf[List[List[Tree]]], stats.asInstanceOf[Seq[Tree]]) + case _ => c.abort(c.enclosingPosition, s"${ErrorMessage.ONLY_CLASS} classDef: $classDecl") } - } - val classFieldDefinitionsOnlyAssignExpr = getClassMemberVarDefOnlyAssignExpr + // Extract the internal fields of members belonging to the class, but not in primary constructor. + val classFieldDefinitions = getClassMemberValDefs(annotteeClassDefinitions) + val excludeFields = args._2 + + /** + * Extract the internal fields of members belonging to the class, but not in primary constructor and only `var`. + */ + def getClassMemberVarDefOnlyAssignExpr: Seq[Tree] = { + import c.universe._ + getClassMemberValDefs(annotteeClassDefinitions).filter(_ match { + case q"$mods var $tname: $tpt = $expr" if !excludeFields.contains(tname.asInstanceOf[TermName].decodedName.toString) => true + case _ => false + }).map { + case q"$mods var $pat = $expr" => + // TODO getClass RETURN a java type, maybe we can try use class reflect to get the fields type name. + q"$pat: ${TypeName(toScalaType(evalTree(expr.asInstanceOf[Tree]).getClass.getTypeName))}" + case q"$mods var $tname: $tpt = $expr" => q"$tname: $tpt" + } + } - if (classFieldDefinitionsOnlyAssignExpr.isEmpty) { - c.abort(c.enclosingPosition, s"Annotation is only supported on class when the internal field (declare as 'var') is nonEmpty. classDef: $classDecl") - } + val classFieldDefinitionsOnlyAssignExpr = getClassMemberVarDefOnlyAssignExpr - val annotteeClassFieldNames = classFieldDefinitions.filter(_ match { - case q"$mods var $tname: $tpt = $expr" if !excludeFields.contains(tname.asInstanceOf[TermName].decodedName.toString) => true - case _ => false - }).map { - case q"$mods var $tname: $tpt = $expr" => tname.asInstanceOf[TermName] - } + if (classFieldDefinitionsOnlyAssignExpr.isEmpty) { + c.abort(c.enclosingPosition, s"Annotation is only supported on class when the internal field (declare as 'var') is nonEmpty. classDef: $classDecl") + } - c.info(c.enclosingPosition, s"modifiedDeclaration compDeclOpt: $compDeclOpt, annotteeClassParams: $annotteeClassParams", force = args._1) + val annotteeClassFieldNames = classFieldDefinitions.filter(_ match { + case q"$mods var $tname: $tpt = $expr" if !excludeFields.contains(tname.asInstanceOf[TermName].decodedName.toString) => true + case _ => false + }).map { + case q"$mods var $tname: $tpt = $expr" => tname.asInstanceOf[TermName] + } - // Extract the field of the primary constructor. - val allFieldsTermName = annotteeClassParams.map(f => f.map(ff => getFieldTermName(c)(ff))) + c.info(c.enclosingPosition, s"modifiedDeclaration compDeclOpt: $compDeclOpt, annotteeClassParams: $annotteeClassParams", force = args._1) - /** - * We generate this method with currying, and we have to deal with the first layer of currying alone. - */ - def getThisMethodWithCurrying: c.Tree = { - // not currying // Extract the field of the primary constructor. - val classParamsAssignExpr = getFieldAssignExprs(c)(annotteeClassParams.flatten) - val applyMethod = if (annotteeClassParams.isEmpty || annotteeClassParams.size == 1) { - q""" + val allFieldsTermName = annotteeClassParams.map(f => f.map(ff => getFieldTermName(ff))) + + /** + * We generate this method with currying, and we have to deal with the first layer of currying alone. + */ + def getThisMethodWithCurrying: Tree = { + // not currying + // Extract the field of the primary constructor. + val classParamsAssignExpr = getFieldAssignExprs(annotteeClassParams.flatten) + val applyMethod = if (annotteeClassParams.isEmpty || annotteeClassParams.size == 1) { + q""" def this(..${classParamsAssignExpr ++ classFieldDefinitionsOnlyAssignExpr}) = { this(..${allFieldsTermName.flatten}) ..${annotteeClassFieldNames.map(f => q"this.$f = $f")} } """ - } else { - // NOTE: currying constructor overload must be placed in the first bracket block. - val allClassParamsAssignExpr = annotteeClassParams.map(cc => getFieldAssignExprs(c)(cc)) - q""" + } else { + // NOTE: currying constructor overload must be placed in the first bracket block. + val allClassParamsAssignExpr = annotteeClassParams.map(cc => getFieldAssignExprs(cc)) + q""" def this(..${allClassParamsAssignExpr.head ++ classFieldDefinitionsOnlyAssignExpr})(...${allClassParamsAssignExpr.tail}) = { this(..${allFieldsTermName.head})(...${allFieldsTermName.tail}) ..${annotteeClassFieldNames.map(f => q"this.$f = $f")} } """ + } + applyMethod } - applyMethod - } - val resTree = annotateeClass match { - case q"$mods class $tpname[..$tparams] $ctorMods(...$paramss) extends { ..$earlydefns } with ..$parents { $self => ..$stats }" => - q"$mods class $tpname[..$tparams] $ctorMods(...$paramss) extends { ..$earlydefns } with ..$parents { $self => ..${stats.toList.:+(getThisMethodWithCurrying)} }" + val resTree = annotateeClass match { + case q"$mods class $tpname[..$tparams] $ctorMods(...$paramss) extends { ..$earlydefns } with ..$parents { $self => ..$stats }" => + q"$mods class $tpname[..$tparams] $ctorMods(...$paramss) extends { ..$earlydefns } with ..$parents { $self => ..${stats.toList.:+(getThisMethodWithCurrying)} }" + } + c.Expr[Any](treeResultWithCompanionObject(resTree, annottees: _*)) } - c.Expr[Any](treeResultWithCompanionObject(c)(resTree, annottees: _*)) - } - - val resTree = handleWithImplType(c)(annottees: _*)(modifiedDeclaration) - printTree(c)(force = args._1, resTree.tree) - resTree + val resTree = handleWithImplType(annottees: _*)(modifiedDeclaration) + printTree(force = args._1, resTree.tree) + resTree + } } } diff --git a/src/main/scala/io/github/dreamylost/macros/equalsAndHashCodeMacro.scala b/src/main/scala/io/github/dreamylost/macros/equalsAndHashCodeMacro.scala index eac8451..8b37f8c 100644 --- a/src/main/scala/io/github/dreamylost/macros/equalsAndHashCodeMacro.scala +++ b/src/main/scala/io/github/dreamylost/macros/equalsAndHashCodeMacro.scala @@ -29,73 +29,76 @@ import scala.reflect.macros.whitebox * @since 2021/7/18 * @version 1.0 */ -object equalsAndHashCodeMacro extends MacroCommon { +object equalsAndHashCodeMacro { + + class EqualsAndHashCodeProcessor(override val c: whitebox.Context) extends AbstractMacroProcessor(c) { - def impl(c: whitebox.Context)(annottees: c.Expr[Any]*): c.Expr[Any] = { import c.universe._ - val args: (Boolean, Seq[String]) = extractArgumentsTuple2(c) { - case q"new equalsAndHashCode(verbose=$verbose)" => (evalTree(c)(verbose.asInstanceOf[Tree]), Nil) - case q"new equalsAndHashCode(excludeFields=$excludeFields)" => (false, evalTree(c)(excludeFields.asInstanceOf[Tree])) - case q"new equalsAndHashCode(verbose=$verbose, excludeFields=$excludeFields)" => (evalTree(c)(verbose.asInstanceOf[Tree]), evalTree(c)(excludeFields.asInstanceOf[Tree])) - case q"new equalsAndHashCode()" => (false, Nil) - case _ => c.abort(c.enclosingPosition, ErrorMessage.UNEXPECTED_PATTERN) - } - val annotateeClass: ClassDef = checkAndGetClassDef(c)(annottees: _*) - val isCase: Boolean = isCaseClass(c)(annotateeClass) - if (isCase) { - c.abort(c.enclosingPosition, s"${ErrorMessage.ONLY_CLASS} classDef: $annotateeClass") - } - val excludeFields = args._2 - - def modifiedDeclaration(classDecl: ClassDef, compDeclOpt: Option[ModuleDef] = None): Any = { - val (className, annotteeClassParams, annotteeClassDefinitions, superClasses) = classDecl match { - case q"$mods class $tpname[..$tparams] $ctorMods(...$paramss) extends { ..$earlydefns } with ..$parents { $self => ..$stats }" => - c.info(c.enclosingPosition, s"modifiedDeclaration className: $tpname, paramss: $paramss", force = args._1) - (tpname, paramss.asInstanceOf[List[List[Tree]]], stats.asInstanceOf[Seq[Tree]], parents) - case _ => c.abort(c.enclosingPosition, s"${ErrorMessage.ONLY_CLASS} classDef: $classDecl") - } - val ctorFieldNames = annotteeClassParams.flatten.filter(cf => classParamsIsPrivate(c)(cf)) - val allFieldsTermName = ctorFieldNames.map(f => getFieldTermName(c)(f)) - - c.info(c.enclosingPosition, s"modifiedDeclaration compDeclOpt: $compDeclOpt, ctorFieldNames: $ctorFieldNames, " + - s"annotteeClassParams: $superClasses", force = args._1) - - /** - * Extract the internal fields of members belonging to the class. - */ - def getClassMemberAllTermName: Seq[c.TermName] = { - getClassMemberValDefs(c)(annotteeClassDefinitions).filter(_ match { - case q"$mods var $tname: $tpt = $expr" if !excludeFields.contains(tname.asInstanceOf[TermName].decodedName.toString) => true - case q"$mods val $tname: $tpt = $expr" if !excludeFields.contains(tname.asInstanceOf[TermName].decodedName.toString) => true - case q"$mods val $pat = $expr" if !excludeFields.contains(pat.asInstanceOf[TermName].decodedName.toString) => true - case q"$mods var $pat = $expr" if !excludeFields.contains(pat.asInstanceOf[TermName].decodedName.toString) => true - case _ => false - }).map(f => getFieldTermName(c)(f)) + override def impl(annottees: c.universe.Expr[Any]*): c.universe.Expr[Any] = { + val args: (Boolean, Seq[String]) = extractArgumentsTuple2 { + case q"new equalsAndHashCode(verbose=$verbose)" => (evalTree(verbose.asInstanceOf[Tree]), Nil) + case q"new equalsAndHashCode(excludeFields=$excludeFields)" => (false, evalTree(excludeFields.asInstanceOf[Tree])) + case q"new equalsAndHashCode(verbose=$verbose, excludeFields=$excludeFields)" => (evalTree(verbose.asInstanceOf[Tree]), evalTree(excludeFields.asInstanceOf[Tree])) + case q"new equalsAndHashCode()" => (false, Nil) + case _ => c.abort(c.enclosingPosition, ErrorMessage.UNEXPECTED_PATTERN) } - val existsCanEqual = getClassMemberDefDefs(c)(annotteeClassDefinitions) exists { - case q"$mods def $tname[..$tparams](...$paramss): $tpt = $expr" if tname.toString() == "canEqual" && paramss.nonEmpty => - val params = paramss.asInstanceOf[List[List[Tree]]].flatten.map(pp => getMethodParamName(c)(pp)) - params.exists(p => p.decodedName.toString == "Any") - case _ => false + val annotateeClass: ClassDef = checkAndGetClassDef(annottees: _*) + val isCase: Boolean = isCaseClass(annotateeClass) + if (isCase) { + c.abort(c.enclosingPosition, s"${ErrorMessage.ONLY_CLASS} classDef: $annotateeClass") } + val excludeFields = args._2 + + def modifiedDeclaration(classDecl: ClassDef, compDeclOpt: Option[ModuleDef] = None): Any = { + val (className, annotteeClassParams, annotteeClassDefinitions, superClasses) = classDecl match { + case q"$mods class $tpname[..$tparams] $ctorMods(...$paramss) extends { ..$earlydefns } with ..$parents { $self => ..$stats }" => + c.info(c.enclosingPosition, s"modifiedDeclaration className: $tpname, paramss: $paramss", force = args._1) + (tpname, paramss.asInstanceOf[List[List[Tree]]], stats.asInstanceOf[Seq[Tree]], parents) + case _ => c.abort(c.enclosingPosition, s"${ErrorMessage.ONLY_CLASS} classDef: $classDecl") + } + val ctorFieldNames = annotteeClassParams.flatten.filter(cf => classParamsIsPrivate(cf)) + val allFieldsTermName = ctorFieldNames.map(f => getFieldTermName(f)) + + c.info(c.enclosingPosition, s"modifiedDeclaration compDeclOpt: $compDeclOpt, ctorFieldNames: $ctorFieldNames, " + + s"annotteeClassParams: $superClasses", force = args._1) + + /** + * Extract the internal fields of members belonging to the class. + */ + def getClassMemberAllTermName: Seq[TermName] = { + getClassMemberValDefs(annotteeClassDefinitions).filter(_ match { + case q"$mods var $tname: $tpt = $expr" if !excludeFields.contains(tname.asInstanceOf[TermName].decodedName.toString) => true + case q"$mods val $tname: $tpt = $expr" if !excludeFields.contains(tname.asInstanceOf[TermName].decodedName.toString) => true + case q"$mods val $pat = $expr" if !excludeFields.contains(pat.asInstanceOf[TermName].decodedName.toString) => true + case q"$mods var $pat = $expr" if !excludeFields.contains(pat.asInstanceOf[TermName].decodedName.toString) => true + case _ => false + }).map(f => getFieldTermName(f)) + } - // + super.hashCode - val SDKClasses = Set("java.lang.Object", "scala.AnyRef") - val canEqualsExistsInSuper = if (superClasses.nonEmpty && !superClasses.forall(sc => SDKClasses.contains(sc.toString()))) { // TODO better way - true - } else false - - // equals template - def ==(termNames: Seq[TermName]): c.universe.Tree = { - val getEqualsExpr = (termName: TermName) => { - q"this.$termName.equals(t.$termName)" + val existsCanEqual = getClassMemberDefDefs(annotteeClassDefinitions) exists { + case q"$mods def $tname[..$tparams](...$paramss): $tpt = $expr" if tname.toString() == "canEqual" && paramss.nonEmpty => + val params = paramss.asInstanceOf[List[List[Tree]]].flatten.map(pp => getMethodParamName(pp)) + params.exists(p => p.decodedName.toString == "Any") + case _ => false } - val equalsExprs = termNames.map(getEqualsExpr) - val modifiers = if (canEqualsExistsInSuper) Modifiers(Flag.OVERRIDE, typeNames.EMPTY, List()) else Modifiers(NoFlags, typeNames.EMPTY, List()) - val canEqual = if (existsCanEqual) q"" else q"$modifiers def canEqual(that: Any) = that.isInstanceOf[$className]" - q""" + + // + super.hashCode + val SDKClasses = Set("java.lang.Object", "scala.AnyRef") + val canEqualsExistsInSuper = if (superClasses.nonEmpty && !superClasses.forall(sc => SDKClasses.contains(sc.toString()))) { // TODO better way + true + } else false + + // equals template + def ==(termNames: Seq[TermName]): Tree = { + val getEqualsExpr = (termName: TermName) => { + q"this.$termName.equals(t.$termName)" + } + val equalsExprs = termNames.map(getEqualsExpr) + val modifiers = if (canEqualsExistsInSuper) Modifiers(Flag.OVERRIDE, typeNames.EMPTY, List()) else Modifiers(NoFlags, typeNames.EMPTY, List()) + val canEqual = if (existsCanEqual) q"" else q"$modifiers def canEqual(that: Any) = that.isInstanceOf[$className]" + q""" $canEqual override def equals(that: Any): Boolean = @@ -104,54 +107,55 @@ object equalsAndHashCodeMacro extends MacroCommon { case _ => false } """ - } + } - // hashcode template - def ##(termNames: Seq[TermName]): c.universe.Tree = { - // the algorithm see https://alvinalexander.com/scala/how-to-define-equals-hashcode-methods-in-scala-object-equality/ - // We use default 1. - if (!canEqualsExistsInSuper) { - q""" + // hashcode template + def ##(termNames: Seq[TermName]): Tree = { + // the algorithm see https://alvinalexander.com/scala/how-to-define-equals-hashcode-methods-in-scala-object-equality/ + // We use default 1. + if (!canEqualsExistsInSuper) { + q""" override def hashCode(): Int = { val state = Seq(..$termNames) state.map(_.hashCode()).foldLeft(0)((a, b) => 31 * a + b) } """ - } else { - q""" + } else { + q""" override def hashCode(): Int = { val state = Seq(..$termNames) state.map(_.hashCode()).foldLeft(0)((a, b) => 31 * a + b) + super.hashCode } """ + } } - } - val allTernNames = allFieldsTermName ++ getClassMemberAllTermName - val hashcode = ##(allTernNames) - val equals = ==(allTernNames) - val equalsAndHashcode = - q""" + val allTernNames = allFieldsTermName ++ getClassMemberAllTermName + val hashcode = ##(allTernNames) + val equals = ==(allTernNames) + val equalsAndHashcode = + q""" ..$equals $hashcode """ - // return with object if it exists - val resTree = annotateeClass match { - case q"$mods class $tpname[..$tparams] $ctorMods(...$paramss) extends { ..$earlydefns } with ..$parents { $self => ..$stats }" => - val originalStatus = q"{ ..$stats }" - val append = - q""" + // return with object if it exists + val resTree = annotateeClass match { + case q"$mods class $tpname[..$tparams] $ctorMods(...$paramss) extends { ..$earlydefns } with ..$parents { $self => ..$stats }" => + val originalStatus = q"{ ..$stats }" + val append = + q""" ..$originalStatus ..$equalsAndHashcode """ - q"$mods class $tpname[..$tparams] $ctorMods(...$paramss) extends { ..$earlydefns } with ..$parents { $self => ..${append} }" + q"$mods class $tpname[..$tparams] $ctorMods(...$paramss) extends { ..$earlydefns } with ..$parents { $self => ..${append} }" + } + c.Expr[Any](treeResultWithCompanionObject(resTree, annottees: _*)) } - c.Expr[Any](treeResultWithCompanionObject(c)(resTree, annottees: _*)) - } - val resTree = handleWithImplType(c)(annottees: _*)(modifiedDeclaration) - printTree(c)(force = args._1, resTree.tree) + val resTree = handleWithImplType(annottees: _*)(modifiedDeclaration) + printTree(force = args._1, resTree.tree) - resTree + resTree + } } } diff --git a/src/main/scala/io/github/dreamylost/macros/jsonMacro.scala b/src/main/scala/io/github/dreamylost/macros/jsonMacro.scala index 3c64bfe..39d8670 100644 --- a/src/main/scala/io/github/dreamylost/macros/jsonMacro.scala +++ b/src/main/scala/io/github/dreamylost/macros/jsonMacro.scala @@ -29,49 +29,52 @@ import scala.reflect.macros.whitebox * @since 2021/7/7 * @version 1.0 */ -object jsonMacro extends MacroCommon { - def impl(c: whitebox.Context)(annottees: c.Expr[Any]*): c.Expr[Any] = { +object jsonMacro { + + class JsonProcessor(override val c: whitebox.Context) extends AbstractMacroProcessor(c) { + import c.universe._ - def jsonFormatter(className: TypeName, fields: List[Tree]): c.universe.Tree = { - fields.length match { - case 0 => c.abort(c.enclosingPosition, "Cannot create json formatter for case class with no fields") - case _ => - c.info(c.enclosingPosition, s"jsonFormatter className: $className, field length: ${fields.length}", force = true) - q"implicit val jsonAnnotationFormat = play.api.libs.json.Json.format[$className]" + override def impl(annottees: c.universe.Expr[Any]*): c.universe.Expr[Any] = { + def jsonFormatter(className: TypeName, fields: List[Tree]): Tree = { + fields.length match { + case 0 => c.abort(c.enclosingPosition, "Cannot create json formatter for case class with no fields") + case _ => + c.info(c.enclosingPosition, s"jsonFormatter className: $className, field length: ${fields.length}", force = true) + q"implicit val jsonAnnotationFormat = play.api.libs.json.Json.format[$className]" + } } - } - // The dependent type need aux-pattern in scala2. Now let's get around this. - def modifiedDeclaration(classDecl: ClassDef, compDeclOpt: Option[ModuleDef] = None): Any = { - val (className, fields) = classDecl match { - case q"$mods class $tpname[..$tparams] $ctorMods(...$paramss) extends ..$bases { ..$body }" => - if (!mods.asInstanceOf[Modifiers].hasFlag(Flag.CASE)) { - c.abort(c.enclosingPosition, s"Annotation is only supported on case class. classDef: $classDecl, mods: $mods") - } else { - c.info(c.enclosingPosition, s"modifiedDeclaration className: $tpname, paramss: $paramss", force = true) - (tpname, paramss.asInstanceOf[List[List[Tree]]]) - } - case _ => c.abort(c.enclosingPosition, s"${ErrorMessage.ONLY_CLASS} classDef: $classDecl") - } - c.info(c.enclosingPosition, s"modifiedDeclaration className: $className, fields: $fields", force = true) - val cName = className.asInstanceOf[TypeName] - val format = jsonFormatter(cName, fields.flatten) - val compDecl = modifiedCompanion(c)(compDeclOpt, format, cName) - c.info(c.enclosingPosition, s"format: $format, compDecl: $compDecl", force = true) - // Return both the class and companion object declarations - c.Expr( - q""" + // The dependent type need aux-pattern in scala2. Now let's get around this. + def modifiedDeclaration(classDecl: ClassDef, compDeclOpt: Option[ModuleDef] = None): Any = { + val (className, fields) = classDecl match { + case q"$mods class $tpname[..$tparams] $ctorMods(...$paramss) extends ..$bases { ..$body }" => + if (!mods.asInstanceOf[Modifiers].hasFlag(Flag.CASE)) { + c.abort(c.enclosingPosition, s"Annotation is only supported on case class. classDef: $classDecl, mods: $mods") + } else { + c.info(c.enclosingPosition, s"modifiedDeclaration className: $tpname, paramss: $paramss", force = true) + (tpname, paramss.asInstanceOf[List[List[Tree]]]) + } + case _ => c.abort(c.enclosingPosition, s"${ErrorMessage.ONLY_CLASS} classDef: $classDecl") + } + c.info(c.enclosingPosition, s"modifiedDeclaration className: $className, fields: $fields", force = true) + val cName = className.asInstanceOf[TypeName] + val format = jsonFormatter(cName, fields.flatten) + val compDecl = modifiedCompanion(compDeclOpt, format, cName) + c.info(c.enclosingPosition, s"format: $format, compDecl: $compDecl", force = true) + // Return both the class and companion object declarations + c.Expr( + q""" $classDecl $compDecl """) - } + } - c.info(c.enclosingPosition, s"json annottees: $annottees", force = true) - val resTree = handleWithImplType(c)(annottees: _*)(modifiedDeclaration) - printTree(c)(force = true, resTree.tree) + val resTree = handleWithImplType(annottees: _*)(modifiedDeclaration) + printTree(force = true, resTree.tree) - resTree + resTree + } } } diff --git a/src/main/scala/io/github/dreamylost/macros/logMacro.scala b/src/main/scala/io/github/dreamylost/macros/logMacro.scala index 7c59292..190d894 100644 --- a/src/main/scala/io/github/dreamylost/macros/logMacro.scala +++ b/src/main/scala/io/github/dreamylost/macros/logMacro.scala @@ -33,55 +33,60 @@ import scala.reflect.macros.whitebox * @since 2021/7/7 * @version 1.0 */ -object logMacro extends MacroCommon { - def impl(c: whitebox.Context)(annottees: c.Expr[Any]*): c.Expr[Any] = { +object logMacro { + + class LogProcessor(override val c: whitebox.Context) extends AbstractMacroProcessor(c) { + import c.universe._ - def getLogType(logType: c.Tree): LogType = { - if (logType.children.exists(t => t.toString().contains(PACKAGE))) { - evalTree(c)(logType.asInstanceOf[Tree]) // TODO remove asInstanceOf - } else { - LogType.getLogType(logType.toString()) + + override def impl(annottees: c.universe.Expr[Any]*): c.universe.Expr[Any] = { + def getLogType(logType: Tree): LogType = { + if (logType.children.exists(t => t.toString().contains(PACKAGE))) { + evalTree(logType.asInstanceOf[Tree]) // TODO remove asInstanceOf + } else { + LogType.getLogType(logType.toString()) + } + } + val args: (Boolean, LogType) = extractArgumentsTuple2 { + case q"new log(logType=$logType)" => + val tpe = getLogType(logType.asInstanceOf[Tree]) + (false, tpe) + case q"new log(verbose=$verbose)" => (evalTree(verbose.asInstanceOf[Tree]), LogType.JLog) + case q"new log($logType)" => + val tpe = getLogType(logType.asInstanceOf[Tree]) + (false, tpe) + case q"new log(verbose=$verbose, logType=$logType)" => + val tpe = getLogType(logType.asInstanceOf[Tree]) + (evalTree(verbose.asInstanceOf[Tree]), tpe) + case q"new log()" => (false, LogType.JLog) + case _ => c.abort(c.enclosingPosition, ErrorMessage.UNEXPECTED_PATTERN) } - } - val args: (Boolean, LogType) = extractArgumentsTuple2(c) { - case q"new log(logType=$logType)" => - val tpe = getLogType(logType.asInstanceOf[Tree]) - (false, tpe) - case q"new log(verbose=$verbose)" => (evalTree(c)(verbose.asInstanceOf[Tree]), LogType.JLog) - case q"new log($logType)" => - val tpe = getLogType(logType.asInstanceOf[Tree]) - (false, tpe) - case q"new log(verbose=$verbose, logType=$logType)" => - val tpe = getLogType(logType.asInstanceOf[Tree]) - (evalTree(c)(verbose.asInstanceOf[Tree]), tpe) - case q"new log()" => (false, LogType.JLog) - case _ => c.abort(c.enclosingPosition, ErrorMessage.UNEXPECTED_PATTERN) - } - c.info(c.enclosingPosition, s"annottees: $annottees, args: $args", force = args._1) + c.info(c.enclosingPosition, s"annottees: $annottees, args: $args", force = args._1) - val logTree = annottees.map(_.tree) match { - // Match a class, and expand, get class/object name. - case q"$mods class $tpname[..$tparams] $ctorMods(...$paramss) extends { ..$earlydefns } with ..$parents { $self => ..$stats }" :: _ => - LogType.getLogImpl(args._2).getTemplate(c)(tpname.asInstanceOf[TypeName].toTermName.decodedName.toString, isClass = true) - case q"$mods object $tpname extends { ..$earlydefns } with ..$parents { $self => ..$stats }" :: _ => - LogType.getLogImpl(args._2).getTemplate(c)(tpname.asInstanceOf[TermName].decodedName.toString, isClass = false) - case _ => c.abort(c.enclosingPosition, s"Annotation is only supported on class or object.") - } + val logTree = annottees.map(_.tree) match { + // Match a class, and expand, get class/object name. + case q"$mods class $tpname[..$tparams] $ctorMods(...$paramss) extends { ..$earlydefns } with ..$parents { $self => ..$stats }" :: _ => + LogType.getLogImpl(args._2).getTemplate(c)(tpname.asInstanceOf[TypeName].toTermName.decodedName.toString, isClass = true) + case q"$mods object $tpname extends { ..$earlydefns } with ..$parents { $self => ..$stats }" :: _ => + LogType.getLogImpl(args._2).getTemplate(c)(tpname.asInstanceOf[TermName].decodedName.toString, isClass = false) + case _ => c.abort(c.enclosingPosition, s"Annotation is only supported on class or object.") + } - // add result into class - val resTree = annottees.map(_.tree) match { - case q"$mods class $tpname[..$tparams] $ctorMods(...$paramss) extends { ..$earlydefns } with ..$parents { $self => ..$stats }" :: _ => - val resTree = q"$mods class $tpname[..$tparams] $ctorMods(...$paramss) extends { ..$earlydefns } with ..$parents { $self => ..${List(logTree) ::: stats.toList} }" - treeResultWithCompanionObject(c)(resTree, annottees: _*) //we should return with companion object. Even if we didn't change it. - case q"$mods object $tpname extends { ..$earlydefns } with ..$parents { $self => ..$stats }" :: _ => - q"$mods object $tpname extends { ..$earlydefns } with ..$parents { $self => ..${List(logTree) ::: stats.toList} }" - // Note: If a class is annotated and it has a companion, then both are passed into the macro. - // (But not vice versa - if an object is annotated and it has a companion class, only the object itself is expanded). - // see https://docs.scala-lang.org/overviews/macros/annotations.html - } + // add result into class + val resTree = annottees.map(_.tree) match { + case q"$mods class $tpname[..$tparams] $ctorMods(...$paramss) extends { ..$earlydefns } with ..$parents { $self => ..$stats }" :: _ => + q"$mods class $tpname[..$tparams] $ctorMods(...$paramss) extends { ..$earlydefns } with ..$parents { $self => ..${List(logTree) ::: stats.toList} }" + case q"$mods object $tpname extends { ..$earlydefns } with ..$parents { $self => ..$stats }" :: _ => + q"$mods object $tpname extends { ..$earlydefns } with ..$parents { $self => ..${List(logTree) ::: stats.toList} }" + // Note: If a class is annotated and it has a companion, then both are passed into the macro. + // (But not vice versa - if an object is annotated and it has a companion class, only the object itself is expanded). + // see https://docs.scala-lang.org/overviews/macros/annotations.html + } - printTree(c)(force = args._1, resTree) - c.Expr[Any](resTree) + val res = treeResultWithCompanionObject(resTree, annottees: _*) + printTree(force = args._1, res) + c.Expr[Any](resTree) + } } } diff --git a/src/main/scala/io/github/dreamylost/macros/synchronizedMacro.scala b/src/main/scala/io/github/dreamylost/macros/synchronizedMacro.scala index 641a253..bb1f47c 100644 --- a/src/main/scala/io/github/dreamylost/macros/synchronizedMacro.scala +++ b/src/main/scala/io/github/dreamylost/macros/synchronizedMacro.scala @@ -29,34 +29,37 @@ import scala.reflect.macros.whitebox * @since 2021/7/7 * @version 1.0 */ -object synchronizedMacro extends MacroCommon { - def impl(c: whitebox.Context)(annottees: c.Expr[Any]*): c.Expr[Any] = { - import c.universe._ +object synchronizedMacro { - val args: (Boolean, String) = extractArgumentsTuple2(c) { - case q"new synchronized(verbose=$verbose, lockedName=$lock)" => (evalTree(c)(verbose.asInstanceOf[Tree]), evalTree(c)(lock.asInstanceOf[Tree])) - case q"new synchronized(lockedName=$lock)" => (false, evalTree(c)(lock.asInstanceOf[Tree])) - case q"new synchronized()" => (false, "this") - case _ => c.abort(c.enclosingPosition, ErrorMessage.UNEXPECTED_PATTERN) - } + class SynchronizedProcessor(override val c: whitebox.Context) extends AbstractMacroProcessor(c) { - c.info(c.enclosingPosition, s"annottees: $annottees", force = args._1) + import c.universe._ - val resTree = annottees map (_.tree) match { - // Match a method, and expand. - case _@ q"$modrs def $tname[..$tparams](...$paramss): $tpt = $expr" :: _ => - if (args._2 != null) { - if (args._2 == "this") { - q"""def $tname[..$tparams](...$paramss): $tpt = ${This(TypeName(""))}.synchronized { $expr }""" + override def impl(annottees: c.universe.Expr[Any]*): c.universe.Expr[Any] = { + val args: (Boolean, String) = extractArgumentsTuple2 { + case q"new synchronized(verbose=$verbose, lockedName=$lock)" => (evalTree(verbose.asInstanceOf[Tree]), evalTree(lock.asInstanceOf[Tree])) + case q"new synchronized(lockedName=$lock)" => (false, evalTree(lock.asInstanceOf[Tree])) + case q"new synchronized()" => (false, "this") + case _ => c.abort(c.enclosingPosition, ErrorMessage.UNEXPECTED_PATTERN) + } + + val resTree = annottees map (_.tree) match { + // Match a method, and expand. + case _@ q"$modrs def $tname[..$tparams](...$paramss): $tpt = $expr" :: _ => + if (args._2 != null) { + if (args._2 == "this") { + q"""def $tname[..$tparams](...$paramss): $tpt = ${This(TypeName(""))}.synchronized { $expr }""" + } else { + q"""def $tname[..$tparams](...$paramss): $tpt = ${TermName(args._2)}.synchronized { $expr }""" + } } else { - q"""def $tname[..$tparams](...$paramss): $tpt = ${TermName(args._2)}.synchronized { $expr }""" + c.abort(c.enclosingPosition, "Invalid args, lockName cannot be a null!") } - } else { - c.abort(c.enclosingPosition, "Invalid args, lockName cannot be a null!") - } - case _ => c.abort(c.enclosingPosition, "Invalid annotation target: not a method") + case _ => c.abort(c.enclosingPosition, "Invalid annotation target: not a method") + } + printTree(args._1, resTree) + c.Expr[Any](resTree) } - printTree(c)(args._1, resTree) - c.Expr[Any](resTree) } + } diff --git a/src/main/scala/io/github/dreamylost/macros/toStringMacro.scala b/src/main/scala/io/github/dreamylost/macros/toStringMacro.scala index a92a2d8..af14c1c 100644 --- a/src/main/scala/io/github/dreamylost/macros/toStringMacro.scala +++ b/src/main/scala/io/github/dreamylost/macros/toStringMacro.scala @@ -29,132 +29,134 @@ import scala.reflect.macros.whitebox * @since 2021/7/7 * @version 1.0 */ -object toStringMacro extends MacroCommon { +object toStringMacro { private final case class Argument(verbose: Boolean, includeInternalFields: Boolean, includeFieldNames: Boolean, callSuper: Boolean) - def printField(c: whitebox.Context)(argument: Argument, lastParam: Option[String], field: c.universe.Tree): c.universe.Tree = { + class ToStringProcessor(override val c: whitebox.Context) extends AbstractMacroProcessor(c) { + import c.universe._ - // Print one field as +"="+fieldName - if (argument.includeFieldNames) { - lastParam.fold(q"$field") { lp => - field match { - case q"$mods var $tname: $tpt = $expr" => - if (tname.toString() != lp) q"""${tname.toString()}+${"="}+this.$tname+${", "}""" else q"""${tname.toString()}+${"="}+this.$tname""" - case q"$mods val $tname: $tpt = $expr" => - if (tname.toString() != lp) q"""${tname.toString()}+${"="}+this.$tname+${", "}""" else q"""${tname.toString()}+${"="}+this.$tname""" - case _ => q"$field" - } + + override def impl(annottees: c.universe.Expr[Any]*): c.universe.Expr[Any] = { + // extract parameters of annotation, must in order + val arg: (Boolean, Boolean, Boolean, Boolean) = extractArgumentsTuple4 { + case q"new toString(includeInternalFields=$bb, includeFieldNames=$cc, callSuper=$dd)" => + (false, evalTree(bb.asInstanceOf[Tree]), evalTree(cc.asInstanceOf[Tree]), evalTree(dd.asInstanceOf[Tree])) + case q"new toString($aa, $bb, $cc)" => + (evalTree(aa.asInstanceOf[Tree]), evalTree(bb.asInstanceOf[Tree]), evalTree(cc.asInstanceOf[Tree]), false) + + case q"new toString(verbose=$aa, includeInternalFields=$bb, includeFieldNames=$cc, callSuper=$dd)" => + (evalTree(aa.asInstanceOf[Tree]), evalTree(bb.asInstanceOf[Tree]), evalTree(cc.asInstanceOf[Tree]), evalTree(dd.asInstanceOf[Tree])) + case q"new toString(verbose=$aa, includeInternalFields=$bb, includeFieldNames=$cc)" => + (evalTree(aa.asInstanceOf[Tree]), evalTree(bb.asInstanceOf[Tree]), evalTree(cc.asInstanceOf[Tree]), false) + case q"new toString($aa, $bb, $cc, $dd)" => + (evalTree(aa.asInstanceOf[Tree]), evalTree(bb.asInstanceOf[Tree]), evalTree(cc.asInstanceOf[Tree]), evalTree(dd.asInstanceOf[Tree])) + + case q"new toString(includeInternalFields=$bb, includeFieldNames=$cc)" => + (false, evalTree(bb.asInstanceOf[Tree]), evalTree(cc.asInstanceOf[Tree]), false) + case q"new toString(includeInternalFields=$bb)" => + (false, evalTree(bb.asInstanceOf[Tree]), true, false) + case q"new toString(includeFieldNames=$cc)" => + (false, true, evalTree(cc.asInstanceOf[Tree]), false) + case q"new toString()" => (false, true, true, false) + case _ => c.abort(c.enclosingPosition, ErrorMessage.UNEXPECTED_PATTERN) } - } else { - lastParam.fold(q"$field") { lp => - field match { - case q"$mods var $tname: $tpt = $expr" => if (tname.toString() != lp) q"""$tname+${", "}""" else q"""$tname""" - case q"$mods val $tname: $tpt = $expr" => if (tname.toString() != lp) q"""$tname+${", "}""" else q"""$tname""" - case _ => if (field.toString() != lp) q"""$field+${", "}""" else q"""$field""" - } + val argument = Argument(arg._1, arg._2, arg._3, arg._4) + c.info(c.enclosingPosition, s"toString annottees: $annottees", force = argument.verbose) + // Check the type of the class, which can only be defined on the ordinary class + val annotateeClass: ClassDef = checkAndGetClassDef(annottees: _*) + val isCase: Boolean = isCaseClass(annotateeClass) + + c.info(c.enclosingPosition, s"impl argument: $argument, isCase: $isCase", force = argument.verbose) + val resMethod = toStringTemplateImpl(argument, annotateeClass) + val resTree = annotateeClass match { + case q"$mods class $tpname[..$tparams] $ctorMods(...$paramss) extends { ..$earlydefns } with ..$parents { $self => ..$stats }" => + q"$mods class $tpname[..$tparams] $ctorMods(...$paramss) extends { ..$earlydefns } with ..$parents { $self => ..${stats.toList.:+(resMethod)} }" } + val res = treeResultWithCompanionObject(resTree, annottees: _*) + printTree(argument.verbose, res) + c.Expr[Any](res) } - } - private def toStringTemplateImpl(c: whitebox.Context)(argument: Argument, annotateeClass: c.universe.ClassDef): c.universe.Tree = { - import c.universe._ - // For a given class definition, separate the components of the class - val (className, annotteeClassParams, superClasses, annotteeClassDefinitions) = { - annotateeClass match { - case q"$mods class $tpname[..$tparams] $ctorMods(...$paramss) extends { ..$earlydefns } with ..$parents { $self => ..$stats }" => - c.info(c.enclosingPosition, s"parents: $parents", force = argument.verbose) - (tpname, paramss.asInstanceOf[List[List[Tree]]], parents, stats.asInstanceOf[List[Tree]]) + def printField(argument: Argument, lastParam: Option[String], field: Tree): Tree = { + // Print one field as +"="+fieldName + if (argument.includeFieldNames) { + lastParam.fold(q"$field") { lp => + field match { + case q"$mods var $tname: $tpt = $expr" => + if (tname.toString() != lp) q"""${tname.toString()}+${"="}+this.$tname+${", "}""" else q"""${tname.toString()}+${"="}+this.$tname""" + case q"$mods val $tname: $tpt = $expr" => + if (tname.toString() != lp) q"""${tname.toString()}+${"="}+this.$tname+${", "}""" else q"""${tname.toString()}+${"="}+this.$tname""" + case _ => q"$field" + } + } + } else { + lastParam.fold(q"$field") { lp => + field match { + case q"$mods var $tname: $tpt = $expr" => if (tname.toString() != lp) q"""$tname+${", "}""" else q"""$tname""" + case q"$mods val $tname: $tpt = $expr" => if (tname.toString() != lp) q"""$tname+${", "}""" else q"""$tname""" + case _ => if (field.toString() != lp) q"""$field+${", "}""" else q"""$field""" + } + } + } } - // Check the type of the class, whether it already contains its own toString - val annotteeClassFieldDefinitions = annotteeClassDefinitions.filter(p => p match { - case _: ValDef => true - case mem: MemberDef => - c.info(c.enclosingPosition, s"MemberDef: ${mem.toString}", force = argument.verbose) - if (mem.toString().startsWith("override def toString")) { // TODO better way - c.abort(mem.pos, "'toString' method has already defined, please remove it or not use'@toString'") + + private def toStringTemplateImpl(argument: Argument, annotateeClass: ClassDef): Tree = { + // For a given class definition, separate the components of the class + val (className, annotteeClassParams, superClasses, annotteeClassDefinitions) = { + annotateeClass match { + case q"$mods class $tpname[..$tparams] $ctorMods(...$paramss) extends { ..$earlydefns } with ..$parents { $self => ..$stats }" => + c.info(c.enclosingPosition, s"parents: $parents", force = argument.verbose) + (tpname, paramss.asInstanceOf[List[List[Tree]]], parents, stats.asInstanceOf[List[Tree]]) } - false - case _ => false - }) - - // For the parameters of a given constructor, separate the parameter components and extract the constructor parameters containing val and var - val ctorParams = annotteeClassParams.flatten.map { - case tree @ q"$mods val $tname: $tpt = $expr" => tree - case tree @ q"$mods var $tname: $tpt = $expr" => tree - } - c.info(c.enclosingPosition, s"className: $className, ctorParams: ${ctorParams.toString()}, superClasses: $superClasses", force = argument.verbose) - c.info(c.enclosingPosition, s"className: $className, fields: ${annotteeClassFieldDefinitions.toString()}", force = argument.verbose) - val member = if (argument.includeInternalFields) ctorParams ++ annotteeClassFieldDefinitions else ctorParams + } + // Check the type of the class, whether it already contains its own toString + val annotteeClassFieldDefinitions = annotteeClassDefinitions.filter(p => p match { + case _: ValDef => true + case mem: MemberDef => + c.info(c.enclosingPosition, s"MemberDef: ${mem.toString}", force = argument.verbose) + if (mem.toString().startsWith("override def toString")) { // TODO better way + c.abort(mem.pos, "'toString' method has already defined, please remove it or not use'@toString'") + } + false + case _ => false + }) + + // For the parameters of a given constructor, separate the parameter components and extract the constructor parameters containing val and var + val ctorParams = annotteeClassParams.flatten.map { + case tree @ q"$mods val $tname: $tpt = $expr" => tree + case tree @ q"$mods var $tname: $tpt = $expr" => tree + } + c.info(c.enclosingPosition, s"className: $className, ctorParams: ${ctorParams.toString()}, superClasses: $superClasses", force = argument.verbose) + c.info(c.enclosingPosition, s"className: $className, fields: ${annotteeClassFieldDefinitions.toString()}", force = argument.verbose) + val member = if (argument.includeInternalFields) ctorParams ++ annotteeClassFieldDefinitions else ctorParams - val lastParam = member.lastOption.map { - case v: ValDef => v.name.toTermName.decodedName.toString - case c => c.toString - } - val paramsWithName = member.foldLeft(q"${""}")((res, acc) => q"$res + ${printField(c)(argument, lastParam, acc)}") - //scala/bug https://github.com/scala/bug/issues/3967 not be 'Foo(i=1,j=2)' in standard library - val toString = q"""override def toString: String = ${className.toString()} + ${"("} + $paramsWithName + ${")"}""" - - // Have super class ? - if (argument.callSuper && superClasses.nonEmpty) { - val superClassDef = superClasses.head match { - case tree: Tree => Some(tree) // TODO type check better - case _ => None + val lastParam = member.lastOption.map { + case v: ValDef => v.name.toTermName.decodedName.toString + case c => c.toString } - superClassDef.fold(toString)(_ => { - val superClass = q"${"super="}" - c.info(c.enclosingPosition, s"member: $member, superClass: $superClass, superClassDef: $superClassDef, paramsWithName: $paramsWithName", force = argument.verbose) - q"override def toString: String = StringContext(${className.toString()} + ${"("} + $superClass, ${if (member.nonEmpty) ", " else ""}+$paramsWithName + ${")"}).s(super.toString)" + val paramsWithName = member.foldLeft(q"${""}")((res, acc) => q"$res + ${printField(argument, lastParam, acc)}") + //scala/bug https://github.com/scala/bug/issues/3967 not be 'Foo(i=1,j=2)' in standard library + val toString = q"""override def toString: String = ${className.toString()} + ${"("} + $paramsWithName + ${")"}""" + + // Have super class ? + if (argument.callSuper && superClasses.nonEmpty) { + val superClassDef = superClasses.head match { + case tree: Tree => Some(tree) // TODO type check better + case _ => None + } + superClassDef.fold(toString)(_ => { + val superClass = q"${"super="}" + c.info(c.enclosingPosition, s"member: $member, superClass: $superClass, superClassDef: $superClassDef, paramsWithName: $paramsWithName", force = argument.verbose) + q"override def toString: String = StringContext(${className.toString()} + ${"("} + $superClass, ${if (member.nonEmpty) ", " else ""}+$paramsWithName + ${")"}).s(super.toString)" + } + ) + } else { + toString } - ) - } else { - toString } - } - def impl(c: whitebox.Context)(annottees: c.Expr[Any]*): c.Expr[Any] = { - import c.universe._ - // extract parameters of annotation, must in order - val arg: (Boolean, Boolean, Boolean, Boolean) = extractArgumentsTuple4(c) { - case q"new toString(includeInternalFields=$bb, includeFieldNames=$cc, callSuper=$dd)" => - (false, evalTree(c)(bb.asInstanceOf[Tree]), evalTree(c)(cc.asInstanceOf[Tree]), evalTree(c)(dd.asInstanceOf[Tree])) - case q"new toString($aa, $bb, $cc)" => - (evalTree(c)(aa.asInstanceOf[Tree]), evalTree(c)(bb.asInstanceOf[Tree]), evalTree(c)(cc.asInstanceOf[Tree]), false) - - case q"new toString(verbose=$aa, includeInternalFields=$bb, includeFieldNames=$cc, callSuper=$dd)" => - (evalTree(c)(aa.asInstanceOf[Tree]), evalTree(c)(bb.asInstanceOf[Tree]), evalTree(c)(cc.asInstanceOf[Tree]), evalTree(c)(dd.asInstanceOf[Tree])) - case q"new toString(verbose=$aa, includeInternalFields=$bb, includeFieldNames=$cc)" => - (evalTree(c)(aa.asInstanceOf[Tree]), evalTree(c)(bb.asInstanceOf[Tree]), evalTree(c)(cc.asInstanceOf[Tree]), false) - case q"new toString($aa, $bb, $cc, $dd)" => - (evalTree(c)(aa.asInstanceOf[Tree]), evalTree(c)(bb.asInstanceOf[Tree]), evalTree(c)(cc.asInstanceOf[Tree]), evalTree(c)(dd.asInstanceOf[Tree])) - - case q"new toString(includeInternalFields=$bb, includeFieldNames=$cc)" => - (false, evalTree(c)(bb.asInstanceOf[Tree]), evalTree(c)(cc.asInstanceOf[Tree]), false) - case q"new toString(includeInternalFields=$bb)" => - (false, evalTree(c)(bb.asInstanceOf[Tree]), true, false) - case q"new toString(includeFieldNames=$cc)" => - (false, true, evalTree(c)(cc.asInstanceOf[Tree]), false) - case q"new toString()" => (false, true, true, false) - case _ => c.abort(c.enclosingPosition, ErrorMessage.UNEXPECTED_PATTERN) - } - val argument = Argument(arg._1, arg._2, arg._3, arg._4) - c.info(c.enclosingPosition, s"toString annottees: $annottees", force = argument.verbose) - // Check the type of the class, which can only be defined on the ordinary class - val annotateeClass: ClassDef = checkAndGetClassDef(c)(annottees: _*) - val isCase: Boolean = isCaseClass(c)(annotateeClass) - - c.info(c.enclosingPosition, s"impl argument: $argument, isCase: $isCase", force = argument.verbose) - val resMethod = toStringTemplateImpl(c)(argument, annotateeClass) - val resTree = annotateeClass match { - case q"$mods class $tpname[..$tparams] $ctorMods(...$paramss) extends { ..$earlydefns } with ..$parents { $self => ..$stats }" => - q"$mods class $tpname[..$tparams] $ctorMods(...$paramss) extends { ..$earlydefns } with ..$parents { $self => ..${stats.toList.:+(resMethod)} }" - } - - val res = treeResultWithCompanionObject(c)(resTree, annottees: _*) - printTree(c)(argument.verbose, res) - c.Expr[Any](res) - } } diff --git a/src/main/scala/io/github/dreamylost/synchronized.scala b/src/main/scala/io/github/dreamylost/synchronized.scala index 24cabdf..c79059a 100644 --- a/src/main/scala/io/github/dreamylost/synchronized.scala +++ b/src/main/scala/io/github/dreamylost/synchronized.scala @@ -39,5 +39,5 @@ final class synchronized( verbose: Boolean = false, lockedName: String = "this" ) extends StaticAnnotation { - def macroTransform(annottees: Any*): Any = macro synchronizedMacro.impl + def macroTransform(annottees: Any*): Any = macro synchronizedMacro.SynchronizedProcessor.impl } diff --git a/src/main/scala/io/github/dreamylost/toString.scala b/src/main/scala/io/github/dreamylost/toString.scala index 49bd034..2a74c7a 100644 --- a/src/main/scala/io/github/dreamylost/toString.scala +++ b/src/main/scala/io/github/dreamylost/toString.scala @@ -43,5 +43,5 @@ final class toString( includeFieldNames: Boolean = true, callSuper: Boolean = false ) extends StaticAnnotation { - def macroTransform(annottees: Any*): Any = macro toStringMacro.impl + def macroTransform(annottees: Any*): Any = macro toStringMacro.ToStringProcessor.impl } diff --git a/src/test/scala/io/github/dreamylost/ApplyTest.scala b/src/test/scala/io/github/dreamylost/ApplyTest.scala index d3cb72f..320072d 100644 --- a/src/test/scala/io/github/dreamylost/ApplyTest.scala +++ b/src/test/scala/io/github/dreamylost/ApplyTest.scala @@ -54,20 +54,24 @@ class ApplyTest extends AnyFlatSpec with Matchers { object B3 println(B3(1, 2, None, None)) } - "apply2" should "failed at class" in { - // FAILED, not support currying!! - """@apply @toString class C(int: Int, val j: Int, var k: Option[String] = None, t: Option[Long] = Some(1L))(o: Int = 1)""" shouldNot compile + "apply2" should "failed on case class" in { + """@apply @toString case class C3(int: Int, val j: Int, var k: Option[String] = None, t: Option[Long] = Some(1L))(o: Int = 1)""" shouldNot compile } "apply3" should "ok with currying" in { + """@apply @toString class C2(int: Int, val j: Int, var k: Option[String] = None, t: Option[Long] = Some(1L))(o: Int = 1)""" should compile + @apply + @toString class C1(int: Int, val j: Int, var k: Option[String] = None, t: Option[Long] = Some(1L))(o: Int = 1) @apply @toString class B3(int: Int)(val j: Int)(var k: Option[String] = None)(t: Option[Long] = Some(1L)) + @apply + @toString class B4(int: Int, a: Seq[Seq[String]])(val j: Int, b: Seq[String])(var k: Option[String] = None, c: Seq[Option[String]])(t: Option[Long] = Some(1L)) } "apply4" should "ok with generic" in { @apply - @toString class B3[T, U](int: T)(val j: U) - println(B3(1)(2)) + @toString class B3[T, U](int: T, yy: Int)(val j: U) + println(B3(1, 2)(2)) @toString @apply class B4[T, U](int: T, val j: U) diff --git a/src/test/scala/io/github/dreamylost/BuilderTest.scala b/src/test/scala/io/github/dreamylost/BuilderTest.scala index 36113c2..e9e5860 100644 --- a/src/test/scala/io/github/dreamylost/BuilderTest.scala +++ b/src/test/scala/io/github/dreamylost/BuilderTest.scala @@ -38,6 +38,7 @@ class BuilderTest extends AnyFlatSpec with Matchers { // field : val i: Int = 0, so default value is "_" val ret = TestClass1.builder().i(1).j(0).x("x").build() println(ret) + assert(TestClass1.builder().getClass.getTypeName == "io.github.dreamylost.BuilderTest$TestClass1$2$TestClass1Builder") assert(ret.toString == "TestClass1(1,0,x,Some())") } diff --git a/src/test/scala/io/github/dreamylost/EqualsAndHashCodeTest.scala b/src/test/scala/io/github/dreamylost/EqualsAndHashCodeTest.scala index a731c20..a363a5d 100644 --- a/src/test/scala/io/github/dreamylost/EqualsAndHashCodeTest.scala +++ b/src/test/scala/io/github/dreamylost/EqualsAndHashCodeTest.scala @@ -98,11 +98,41 @@ class EqualsAndHashCodeTest extends AnyFlatSpec with Matchers { } "equals3" should "ok even if exists a canEqual" in { + @equalsAndHashCode + class Employee1(name: String, age: Int, var role: String) extends Person(name, age) { + override def canEqual(that: Any) = that.getClass == classOf[Employee1]; + } """ | @equalsAndHashCode - | class Employee(name: String, age: Int, var role: String) extends Person(name, age) { + | class Employee2(name: String, age: Int, var role: String) extends Person(name, age) { | override def canEqual(that: Any) = that.getClass == classOf[Employee]; | } |""".stripMargin should compile } + + "equals4" should "ok when there are members" in { + @equalsAndHashCode + class Employee1(name: String, age: Int, var role: String) extends Person(name, age) { + val i = 0 + } + """ + | @equalsAndHashCode + | class Employee2(name: String, age: Int, var role: String) extends Person(name, age) { + | val i = 0 + | } + |""".stripMargin should compile + + @equalsAndHashCode + class Employee3(name: String, age: Int, var role: String) extends Person(name, age) { + val i = 0 + def hello: String = ??? + } + """ + | @equalsAndHashCode + | class Employee4(name: String, age: Int, var role: String) extends Person(name, age) { + | val i = 0 + | def hello: String = ??? + | } + |""".stripMargin should compile + } } diff --git a/src/test/scala/io/github/dreamylost/LogTest.scala b/src/test/scala/io/github/dreamylost/LogTest.scala index 1b66c5e..51eed26 100644 --- a/src/test/scala/io/github/dreamylost/LogTest.scala +++ b/src/test/scala/io/github/dreamylost/LogTest.scala @@ -141,22 +141,29 @@ class LogTest extends AnyFlatSpec with Matchers { |""".stripMargin should compile } - "log10 slf4j" should "failed on case class and it object" in { + "log10 slf4j" should "ok on case class and it object" in { + @log(logType = LogType.JLog) + @builder case class TestClass6_1(val i: Int = 0, var j: Int) { + log.info("hello world") + } + @log(logType = io.github.dreamylost.logs.LogType.Slf4j) object TestClass6_1 { + log.info("hello world"); builder() + } """ - | @log(io.github.dreamylost.logs.LogType.Slf4j) - | @builder case class TestClass6(val i: Int = 0, var j: Int) { + | @log(verbose=false, logType = LogType.JLog) + | @builder case class TestClass6_2(val i: Int = 0, var j: Int) { | log.info("hello world") | } - | @log(logType = io.github.dreamylost.logs.LogType.Slf4j) object TestClass6 { + | @log(logType = io.github.dreamylost.logs.LogType.Slf4j) object TestClass6_2 { | log.info("hello world"); builder() | } - |""".stripMargin shouldNot compile //The context of class was not passed in object macro + |""".stripMargin should compile } "log11 slf4j" should "ok on class and it object" in { """ | @log(io.github.dreamylost.logs.LogType.Slf4j) - | @builder class TestClass6(val i: Int = 0, var j: Int) { + | @builder class TestClass6(val i: Int = 0, var j: Int) { | log.info("hello world") | } | @log(logType = io.github.dreamylost.logs.LogType.Slf4j) object TestClass6 { diff --git a/src/test/scala/io/github/dreamylost/OthersTest.scala b/src/test/scala/io/github/dreamylost/OthersTest.scala new file mode 100644 index 0000000..c82d9e2 --- /dev/null +++ b/src/test/scala/io/github/dreamylost/OthersTest.scala @@ -0,0 +1,42 @@ +/* + * Copyright (c) 2021 jxnu-liguobin && contributors + * + * Permission is hereby granted, free of charge, to any person obtaining a copy of + * this software and associated documentation files (the "Software"), to deal in + * the Software without restriction, including without limitation the rights to + * use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of + * the Software, and to permit persons to whom the Software is furnished to do so, + * subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS + * FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR + * COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER + * IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN + * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + */ + +package io.github.dreamylost + +import org.scalatest.flatspec.AnyFlatSpec +import org.scalatest.matchers.should.Matchers + +/** + * + * @author 梦境迷离 + * @since 2021/7/24 + * @version 1.0 + */ +class OthersTest extends AnyFlatSpec with Matchers { + "others" should "ok" in { + assert(PACKAGE == "io.github.dreamylost") + + """ + | @builder + | object A + |""".stripMargin shouldNot compile + } +} -- GitLab