提交 8316c627 编写于 作者: 梦境迷离's avatar 梦境迷离

refactor

上级 d9df483c
......@@ -48,6 +48,7 @@ abstract class AbstractMacroProcessor(val c: whitebox.Context) {
*
*/
def createCustomExpr(classDecl: ClassDef, compDeclOpt: Option[ModuleDef] = None): Any = ???
def createCustomExpr(classDeclOpt: Option[ClassDef], compDeclOpt: Option[ModuleDef]): Any = ???
/**
* Subclasses must override the method.
......@@ -98,23 +99,49 @@ abstract class AbstractMacroProcessor(val c: whitebox.Context) {
* @param annottees
* @return Return ClassDef
*/
def checkAndGetClassDef(annottees: Seq[Expr[Any]]): ClassDef = {
def checkGetClassDef(annottees: Seq[Expr[Any]]): ClassDef = {
annottees.map(_.tree).toList match {
case (classDecl: ClassDef) :: Nil => classDecl
case (classDecl: ClassDef) :: (compDecl: ModuleDef) :: Nil => classDecl
case (classDecl: ClassDef) :: (_: ModuleDef) :: Nil => classDecl
case _ => c.abort(c.enclosingPosition, ErrorMessage.UNEXPECTED_PATTERN)
}
}
def uncheckGetClassDef(annottees: Seq[Expr[Any]]): Option[ClassDef] = {
annottees.map(_.tree).toList match {
case (classDecl: ClassDef) :: Nil => Some(classDecl)
case (classDecl: ClassDef) :: (_: ModuleDef) :: Nil => Some(classDecl)
case (_: ModuleDef) :: (classDecl: ClassDef) :: Nil => Some(classDecl)
case _ => None
}
}
/**
* Get companion object if it exists.
* Get the class or object.
*
* @param annottees
* @return Return ClassDef or ModuleDef
*/
def checkGetClassDefOrModuleDef(annottees: Seq[Expr[Any]]): c.universe.ImplDef = {
annottees.map(_.tree).toList match {
case (classDecl: ClassDef) :: Nil => classDecl
case (moduleDef: ModuleDef) :: Nil => moduleDef
case (classDecl: ClassDef) :: (_: ModuleDef) :: Nil => classDecl
case (_: ModuleDef) :: (classDecl: ClassDef) :: Nil => classDecl
case _ => c.abort(c.enclosingPosition, ErrorMessage.ONLY_OBJECT_CLASS)
}
}
/**
* Get object if it exists.
*
* @param annottees
* @return
*/
def getCompanionObject(annottees: Seq[Expr[Any]]): Option[ModuleDef] = {
def getModuleDefOption(annottees: Seq[Expr[Any]]): Option[ModuleDef] = {
annottees.map(_.tree).toList match {
case (_: ClassDef) :: Nil => None
case (compDecl: ModuleDef) :: (_: ClassDef) :: Nil => Some(compDecl)
case (_: ClassDef) :: (compDecl: ModuleDef) :: Nil => Some(compDecl)
case (compDecl: ModuleDef) :: Nil => Some(compDecl)
case _ => c.abort(c.enclosingPosition, ErrorMessage.UNEXPECTED_PATTERN)
......@@ -128,8 +155,8 @@ abstract class AbstractMacroProcessor(val c: whitebox.Context) {
* @param annottees
* @return
*/
def returnWithCompanionObject(resTree: Tree, annottees: Seq[Expr[Any]]): Tree = {
val companionOpt = getCompanionObject(annottees)
def returnWithModuleDef(resTree: Tree, annottees: Seq[Expr[Any]]): Tree = {
val companionOpt = getModuleDefOption(annottees)
companionOpt.fold(resTree) { t =>
q"""
$resTree
......@@ -147,8 +174,8 @@ abstract class AbstractMacroProcessor(val c: whitebox.Context) {
*/
def collectCustomExpr(annottees: Seq[Expr[Any]])
(modifyAction: (ClassDef, Option[ModuleDef]) => Any): Expr[Nothing] = {
val classDef = checkAndGetClassDef(annottees)
val compDecl = getCompanionObject(annottees)
val classDef = checkGetClassDef(annottees)
val compDecl = getModuleDefOption(annottees)
modifyAction(classDef, compDecl).asInstanceOf[Expr[Nothing]]
}
......@@ -389,7 +416,17 @@ abstract class AbstractMacroProcessor(val c: whitebox.Context) {
def mapToClassDeclInfo(classDecl: ClassDef): ClassDefinition = {
val q"$mods class $tpname[..$tparams] $ctorMods(...$paramss) extends { ..$earlydefns } with ..$parents { $self => ..$stats }" = classDecl
val (className, classParamss, classTypeParams) = (tpname.asInstanceOf[TypeName], paramss.asInstanceOf[List[List[Tree]]], tparams.asInstanceOf[List[Tree]])
ClassDefinition(className, classParamss, classTypeParams, stats.asInstanceOf[List[Tree]], parents.asInstanceOf[List[Tree]])
ClassDefinition(self.asInstanceOf[ValDef], mods.asInstanceOf[Modifiers], className, classParamss, classTypeParams, stats.asInstanceOf[List[Tree]], parents.asInstanceOf[List[Tree]])
}
/**
* Extract the necessary structure information of the moduleDef for macro programming.
*
* @param moduleDef
*/
def mapToModuleDeclInfo(moduleDef: ModuleDef): ClassDefinition = {
val q"$mods object $tpname extends { ..$earlydefns } with ..$parents { $self => ..$stats }" = moduleDef
ClassDefinition(self.asInstanceOf[ValDef], mods.asInstanceOf[Modifiers], tpname.asInstanceOf[TermName].toTypeName, Nil, Nil, stats.asInstanceOf[List[Tree]], parents.asInstanceOf[List[Tree]])
}
/**
......@@ -400,13 +437,37 @@ abstract class AbstractMacroProcessor(val c: whitebox.Context) {
* @param classInfoAction Content body added in class definition
* @return
*/
def appendClassBody(classDecl: ClassDef, classInfoAction: ClassDefinition => Seq[Tree]): c.universe.ClassDef = {
def appendClassBody(classDecl: ClassDef, classInfoAction: ClassDefinition => List[Tree]): c.universe.ClassDef = {
val classInfo = mapToClassDeclInfo(classDecl)
val ClassDef(mods, name, tparams, impl) = classDecl
val Template(parents, self, body) = impl
ClassDef(mods, name, tparams, Template(parents, self, body ++ classInfoAction(classInfo)))
}
// def prependClassBody(classDecl: ClassDef, classInfoAction: ClassDefinition => List[Tree]): c.universe.ClassDef = {
// val classInfo = mapToClassDeclInfo(classDecl)
// val ClassDef(mods, name, tparams, impl) = classDecl
// val Template(parents, self, body) = impl
// ClassDef(mods, name, tparams, Template(parents, self, classInfoAction(classInfo) ++ body))
// }
//
// def appendClassSuper(classDecl: ClassDef, classInfoAction: ClassDefinition => List[Tree]): c.universe.ClassDef = {
// val classInfo = mapToClassDeclInfo(classDecl)
// val ClassDef(mods, name, tparams, impl) = classDecl
// val Template(parents, self, body) = impl
// ClassDef(mods, name, tparams, Template(parents ++ classInfoAction(classInfo), self, body))
// }
//
// def appendModuleSuper(moduleDef: ModuleDef, action: => List[Tree]): c.universe.Tree = {
// val classDefinition = mapToModuleDeclInfo(moduleDef)
// q"${classDefinition.mods} object ${classDefinition.className.toTermName} extends { ..${classDefinition.earlydefns} } with ..${classDefinition.superClasses ++ action} { ${classDefinition.self} => ..${classDefinition.body} }"
// }
//
// def prependModuleBody(moduleDef: ModuleDef, action: => List[Tree]): c.universe.Tree = {
// val classDefinition = mapToModuleDeclInfo(moduleDef)
// q"${classDefinition.mods} object ${classDefinition.className.toTermName} extends { ..${classDefinition.earlydefns} } with ..${classDefinition.superClasses} { ${classDefinition.self} => ..${action ++ classDefinition.body} }"
// }
/**
* Modify the method body of the method tree.
*
......@@ -420,11 +481,14 @@ abstract class AbstractMacroProcessor(val c: whitebox.Context) {
}
private[macros] case class ClassDefinition(
self: ValDef,
mods: Modifiers,
className: TypeName,
classParamss: List[List[Tree]],
classTypeParams: List[Tree],
body: List[Tree],
superClasses: List[Tree]
superClasses: List[Tree],
earlydefns: List[Tree] = Nil
)
}
......@@ -55,7 +55,7 @@ object applyMacro {
}
override def impl(annottees: Expr[Any]*): Expr[Any] = {
val annotateeClass: ClassDef = checkAndGetClassDef(annottees)
val annotateeClass: ClassDef = checkGetClassDef(annottees)
if (isCaseClass(annotateeClass)) {
c.abort(c.enclosingPosition, ErrorMessage.ONLY_CASE_CLASS)
}
......
......@@ -101,7 +101,7 @@ object constructorMacro {
override def createCustomExpr(classDecl: ClassDef, compDeclOpt: Option[ModuleDef] = None): Any = {
val resTree = appendClassBody(
classDecl,
classInfo => Seq(getThisMethodWithCurrying(classInfo.classParamss, classInfo.body)))
classInfo => List(getThisMethodWithCurrying(classInfo.classParamss, classInfo.body)))
c.Expr(
q"""
${compDeclOpt.fold(EmptyTree)(x => x)}
......@@ -110,7 +110,7 @@ object constructorMacro {
}
override def impl(annottees: c.universe.Expr[Any]*): c.universe.Expr[Any] = {
val annotateeClass: ClassDef = checkAndGetClassDef(annottees)
val annotateeClass: ClassDef = checkGetClassDef(annottees)
if (isCaseClass(annotateeClass)) {
c.abort(c.enclosingPosition, ErrorMessage.ONLY_CLASS)
}
......
......@@ -44,7 +44,7 @@ object equalsAndHashCodeMacro {
}
override def impl(annottees: c.universe.Expr[Any]*): c.universe.Expr[Any] = {
val annotateeClass: ClassDef = checkAndGetClassDef(annottees)
val annotateeClass: ClassDef = checkGetClassDef(annottees)
if (isCaseClass(annotateeClass)) {
c.abort(c.enclosingPosition, ErrorMessage.ONLY_CLASS)
}
......
......@@ -44,7 +44,7 @@ object jsonMacro {
}
override def impl(annottees: c.universe.Expr[Any]*): c.universe.Expr[Any] = {
val annotateeClass: ClassDef = checkAndGetClassDef(annottees)
val annotateeClass: ClassDef = checkGetClassDef(annottees)
if (!isCaseClass(annotateeClass)) {
c.abort(c.enclosingPosition, ErrorMessage.ONLY_CASE_CLASS)
}
......
......@@ -72,12 +72,20 @@ object logMacro {
}
val resTree = annottees.map(_.tree) match {
case q"$mods class $tpname[..$tparams] $ctorMods(...$paramss) extends { ..$earlydefns } with ..$parents { $self => ..$stats }" :: _ =>
extractArgumentsDetail._2 match {
if(mods.asInstanceOf[Modifiers].hasFlag(Flag.CASE)){
c.abort(c.enclosingPosition, ErrorMessage.ONLY_OBJECT_CLASS)
}
val newClass = extractArgumentsDetail._2 match {
case ScalaLoggingLazy | ScalaLoggingStrict =>
q"$mods class $tpname[..$tparams] $ctorMods(...$paramss) extends { ..$earlydefns } with ..${parents ++ Seq(logTree)} { $self => ..$stats }"
case _ =>
q"$mods class $tpname[..$tparams] $ctorMods(...$paramss) extends { ..$earlydefns } with ..$parents { $self => ..${Seq(logTree) ++ stats} }"
}
val moduleDef = getModuleDefOption(annottees)
q"""
${if (moduleDef.isEmpty) EmptyTree else moduleDef.get}
$newClass
"""
case q"$mods object $tpname extends { ..$earlydefns } with ..$parents { $self => ..$stats }" :: _ =>
extractArgumentsDetail._2 match {
case ScalaLoggingLazy | ScalaLoggingStrict =>
......@@ -89,8 +97,7 @@ object logMacro {
// see https://docs.scala-lang.org/overviews/macros/annotations.html
}
val res = returnWithCompanionObject(resTree, annottees)
printTree(force = extractArgumentsDetail._1, res)
printTree(force = extractArgumentsDetail._1, resTree)
c.Expr[Any](resTree)
}
}
......
......@@ -72,9 +72,9 @@ object toStringMacro {
val argument = Argument(extractArgumentsDetail._1, extractArgumentsDetail._2,
extractArgumentsDetail._3, extractArgumentsDetail._4)
// Check the type of the class, which can only be defined on the ordinary class
val annotateeClass: ClassDef = checkAndGetClassDef(annottees)
val resTree = appendClassBody(annotateeClass, _ => Seq(toStringTemplateImpl(argument, annotateeClass)))
val compDeclOpt = getCompanionObject(annottees)
val annotateeClass: ClassDef = checkGetClassDef(annottees)
val resTree = appendClassBody(annotateeClass, _ => List(toStringTemplateImpl(argument, annotateeClass)))
val compDeclOpt = getModuleDefOption(annottees)
val res = c.Expr(
q"""
${compDeclOpt.fold(EmptyTree)(x => x)}
......
......@@ -45,18 +45,6 @@ class LogTest extends AnyFlatSpec with Matchers {
"""@log(verbose=true, logType=io.github.dreamylost.logs.LogType.JLog) class TestClass6(val i: Int = 0, var j: Int)""" should compile
}
"log2" should "ok on case class" in {
"""@log(verbose=true) case class TestClass1(val i: Int = 0, var j: Int) {
log.info("hello")
}""" should compile
"""@log case class TestClass2(val i: Int = 0, var j: Int)""" should compile
"""@log() case class TestClass3(val i: Int = 0, var j: Int)""" should compile
"""@log(verbose=true) case class TestClass4(val i: Int = 0, var j: Int)""" should compile
"""@log(logType=io.github.dreamylost.logs.LogType.JLog) case class TestClass5(val i: Int = 0, var j: Int)""" should compile
"""@log(verbose=true, logType=io.github.dreamylost.logs.LogType.JLog) case class TestClass6(val i: Int = 0, var j: Int)""" should compile
}
"log3" should "ok on object" in {
"""@log(verbose=true) object TestClass1 {
log.info("hello")
......@@ -141,15 +129,7 @@ class LogTest extends AnyFlatSpec with Matchers {
|""".stripMargin should compile
}
"log10 slf4j" should "ok on case class and it object" in {
@log(logType = LogType.JLog)
@builder case class TestClass6_1(val i: Int = 0, var j: Int) {
log.info("hello world")
}
@log(logType = io.github.dreamylost.logs.LogType.Slf4j) object TestClass6_1 {
log.info("hello world");
builder()
}
"log10 slf4j" should "failed on case class" in {
"""
| @log(verbose=false, logType = LogType.JLog)
| @builder case class TestClass6_2(val i: Int = 0, var j: Int) {
......@@ -158,7 +138,7 @@ class LogTest extends AnyFlatSpec with Matchers {
| @log(logType = io.github.dreamylost.logs.LogType.Slf4j) object TestClass6_2 {
| log.info("hello world"); builder()
| }
|""".stripMargin should compile
|""".stripMargin shouldNot compile
}
"log11 slf4j" should "ok on class and it object" in {
......@@ -242,14 +222,6 @@ class LogTest extends AnyFlatSpec with Matchers {
| log.info("hello world")
| }
|""".stripMargin should compile
"""
| import io.github.dreamylost.logs.LogType
| @log(logType = LogType.ScalaLoggingLazy)
| case class TestClass5(val i: Int = 0, var j: Int) {
| log.info("hello world")
| }
|""".stripMargin should compile
}
"log14 scala loggging strict" should "ok when exists super class" in {
......@@ -281,20 +253,8 @@ class LogTest extends AnyFlatSpec with Matchers {
| log.info("hello world")
| }
|""".stripMargin should compile
"""
| import io.github.dreamylost.logs.LogType
| @log(logType = LogType.ScalaLoggingStrict)
| case class TestClass5(val i: Int = 0, var j: Int) extends Serializable {
| log.info("hello world")
| }
|""".stripMargin should compile
}
// We must define the class outside so that the macro has been compiled before testing.
@log(logType = LogType.ScalaLoggingStrict)
@json case class TestClass1(val i: Int = 0, var j: Int, x: String, o: Option[String] = Some(""))
"log15 add @transient" should "ok" in {
"""
|val str = Json.toJson(TestClass1(1, 1, "hello")).toString()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册