提交 004baeff 编写于 作者: 梦境迷离's avatar 梦境迷离

refactor

上级 43af7786
......@@ -198,7 +198,7 @@ abstract class AbstractMacroProcessor(val c: whitebox.Context) {
* @param annotteeClassParams
* @return {{ i: Int}}
*/
def getConstructorFieldNameWithType(annotteeClassParams: Seq[Tree]): Seq[Tree] = {
def getConstructorParamsNameWithType(annotteeClassParams: Seq[Tree]): Seq[Tree] = {
annotteeClassParams.map {
case v: ValDef => q"${v.name}: ${v.tpt}"
}
......@@ -239,11 +239,11 @@ abstract class AbstractMacroProcessor(val c: whitebox.Context) {
*
* @param annotteeClassDefinitions
*/
def getClassMemberValDefs(annotteeClassDefinitions: Seq[Tree]): Seq[Tree] = {
def getClassMemberValDefs(annotteeClassDefinitions: Seq[Tree]): Seq[ValDef] = {
annotteeClassDefinitions.filter(_ match {
case _: ValDef => true
case _ => false
})
}).map(_.asInstanceOf[ValDef])
}
/**
......@@ -251,11 +251,11 @@ abstract class AbstractMacroProcessor(val c: whitebox.Context) {
*
* @param annotteeClassDefinitions
*/
def getClassMemberDefDefs(annotteeClassDefinitions: Seq[Tree]): Seq[Tree] = {
def getClassMemberDefDefs(annotteeClassDefinitions: Seq[Tree]): Seq[DefDef] = {
annotteeClassDefinitions.filter(_ match {
case _: DefDef => true
case _ => false
})
}).map(_.asInstanceOf[DefDef])
}
/**
......@@ -293,7 +293,7 @@ abstract class AbstractMacroProcessor(val c: whitebox.Context) {
* @example {{ def apply(int: Int)(j: Int)(k: Option[String])(t: Option[Long]): B3 = new B3(int)(j)(k)(t) }}
*/
def getApplyMethodWithCurrying(typeName: TypeName, fieldss: List[List[Tree]], classTypeParams: List[Tree]): Tree = {
val allFieldsTermName = fieldss.map(f => getConstructorFieldNameWithType(f))
val allFieldsTermName = fieldss.map(f => getConstructorParamsNameWithType(f))
val returnTypeParams = extractClassTypeParamsTypeName(classTypeParams)
// not currying
val applyMethod = if (fieldss.isEmpty || fieldss.size == 1) {
......
......@@ -48,13 +48,13 @@ object builderMacro {
private def getFieldSetMethod(typeName: TypeName, field: Tree, classTypeParams: List[Tree]): Tree = {
val builderClassName = getBuilderClassName(typeName)
val returnTypeParams = extractClassTypeParamsTypeName(classTypeParams)
val valDefMapTo = (v: ValDef) => {
lazy val valDefMapTo = (v: ValDef) => {
q"""
def ${v.name}(${v.name}: ${v.tpt}): $builderClassName[..$returnTypeParams] = {
this.${v.name} = ${v.name}
this
}
"""
def ${v.name}(${v.name}: ${v.tpt}): $builderClassName[..$returnTypeParams] = {
this.${v.name} = ${v.name}
this
}
"""
}
field match {
case v: ValDef => valDefMapTo(v)
......
......@@ -45,18 +45,23 @@ object constructorMacro {
}
}
private def getMutableValDefAndExcludeFields(annotteeClassDefinitions: Seq[Tree]): Seq[c.universe.ValDef] = {
getClassMemberValDefs(annotteeClassDefinitions).filter(v => v.mods.hasFlag(Flag.MUTABLE) &&
!extractArgumentsDetail._2.contains(v.name.decodedName.toString))
}
/**
* Extract the internal fields of members belonging to the class, but not in primary constructor and only `var`.
*/
private def getClassMemberVarDefOnlyAssignExpr(annotteeClassDefinitions: Seq[Tree]): Seq[Tree] = {
getClassMemberValDefs(annotteeClassDefinitions).filter(_ match {
case v: ValDef if v.mods.hasFlag(Flag.MUTABLE) => !extractArgumentsDetail._2.contains(v.name.decodedName.toString)
case _ => false
}).map {
case q"$mods var $pat = $expr" =>
// TODO getClass RETURN a java type, maybe we can try use class reflect to get the fields type name.
q"$pat: ${TypeName(toScalaType(evalTree(expr.asInstanceOf[Tree]).getClass.getTypeName))}"
case q"$mods var $tname: $tpt = $expr" => q"$tname: $tpt"
private def getMemberVarDefTermNameWithType(annotteeClassDefinitions: Seq[Tree]): Seq[Tree] = {
getMutableValDefAndExcludeFields(annotteeClassDefinitions).map {
case v: ValDef if v.mods.hasFlag(Flag.MUTABLE) =>
if (v.tpt.isEmpty) { // val i = 1, tpt is `<type ?>`
// TODO getClass RETURN a java type, maybe we can try use class reflect to get the fields type name.
q"${v.name}: ${TypeName(toScalaType(evalTree(v.rhs).getClass.getTypeName))}"
} else {
q"${v.name}: ${v.tpt}"
}
}
}
......@@ -64,18 +69,13 @@ object constructorMacro {
* We generate this method with currying, and we have to deal with the first layer of currying alone.
*/
private def getThisMethodWithCurrying(annotteeClassParams: List[List[Tree]], annotteeClassDefinitions: Seq[Tree]): Tree = {
val classFieldDefinitionsOnlyAssignExpr = getClassMemberVarDefOnlyAssignExpr(annotteeClassDefinitions)
val classInternalFieldsWithType = getMemberVarDefTermNameWithType(annotteeClassDefinitions)
if (classFieldDefinitionsOnlyAssignExpr.isEmpty) {
if (classInternalFieldsWithType.isEmpty) {
c.abort(c.enclosingPosition, s"${ErrorMessage.ONLY_CLASS} and the internal fields (declare as 'var') should not be Empty.")
}
// Extract the internal fields of members belonging to the class, but not in primary constructor.
val classFieldDefinitions = getClassMemberValDefs(annotteeClassDefinitions)
val annotteeClassFieldNames = classFieldDefinitions.filter(_ match {
case v: ValDef if v.mods.hasFlag(Flag.MUTABLE) => !extractArgumentsDetail._2.contains(v.name.decodedName.toString)
case _ => false
}).map {
val annotteeClassFieldNames = getMutableValDefAndExcludeFields(annotteeClassDefinitions).map {
case v: ValDef if v.mods.hasFlag(Flag.MUTABLE) => v.name
}
......@@ -85,19 +85,19 @@ object constructorMacro {
})
// Extract the field of the primary constructor.
val classParamsAssignExpr = getConstructorFieldNameWithType(annotteeClassParams.flatten)
val classParamsNameWithType = getConstructorParamsNameWithType(annotteeClassParams.flatten)
val applyMethod = if (annotteeClassParams.isEmpty || annotteeClassParams.size == 1) {
q"""
def this(..${classParamsAssignExpr ++ classFieldDefinitionsOnlyAssignExpr}) = {
def this(..${classParamsNameWithType ++ classInternalFieldsWithType}) = {
this(..${allFieldsTermName.flatten})
..${annotteeClassFieldNames.map(f => q"this.$f = $f")}
}
"""
} else {
// NOTE: currying constructor overload must be placed in the first bracket block.
val allClassParamsAssignExpr = annotteeClassParams.map(cc => getConstructorFieldNameWithType(cc))
val allClassCtorParamsNameWithType = annotteeClassParams.map(cc => getConstructorParamsNameWithType(cc))
q"""
def this(..${allClassParamsAssignExpr.head ++ classFieldDefinitionsOnlyAssignExpr})(...${allClassParamsAssignExpr.tail}) = {
def this(..${allClassCtorParamsNameWithType.head ++ classInternalFieldsWithType})(...${allClassCtorParamsNameWithType.tail}) = {
this(..${allFieldsTermName.head})(...${allFieldsTermName.tail})
..${annotteeClassFieldNames.map(f => q"this.$f = $f")}
}
......
......@@ -67,22 +67,18 @@ object equalsAndHashCodeMacro {
/**
* Extract the internal fields of members belonging to the class.
*/
private def getInternalFieldTermNameExcludeLocal(annotteeClassDefinitions: Seq[Tree]): Seq[TermName] = {
getClassMemberValDefs(annotteeClassDefinitions).filter(p => isNotLocalClassMember(p) && (p match {
case v: ValDef => !extractArgumentsDetail._2.contains(v.name.decodedName.toString)
case _ => false
})).map {
case v: ValDef => v.name.toTermName
private def getInternalFieldsTermNameExcludeLocal(annotteeClassDefinitions: Seq[Tree]): Seq[TermName] = {
if (annotteeClassDefinitions.exists(f => isNotLocalClassMember(f))) {
c.info(c.enclosingPosition, s"There is a non private class definition inside the class", extractArgumentsDetail._1)
}
getClassMemberValDefs(annotteeClassDefinitions).filter(p => isNotLocalClassMember(p) &&
!extractArgumentsDetail._2.contains(p.name.decodedName.toString)).map(_.name.toTermName)
}
// equals method
private def getEqualsMethod(className: TypeName, termNames: Seq[TermName], superClasses: Seq[Tree], annotteeClassDefinitions: Seq[Tree]): Tree = {
val existsCanEqual = getClassMemberDefDefs(annotteeClassDefinitions) exists {
val existsCanEqual = getClassMemberDefDefs(annotteeClassDefinitions).exists {
case tree @ q"$mods def $tname[..$tparams](...$paramss): $tpt = $expr" if tname.asInstanceOf[TermName].decodedName.toString == "canEqual" && paramss.nonEmpty =>
if (!isNotLocalClassMember(tree)) {
c.info(c.enclosingPosition, "The canEqual method has been found in class, and method mods exists private[this] or protected[this]", extractArgumentsDetail._1)
}
val params = paramss.asInstanceOf[List[List[Tree]]].flatten.map(pp => getMethodParamName(pp))
params.exists(p => p.decodedName.toString == "Any")
case _ => false
......@@ -135,7 +131,7 @@ object equalsAndHashCodeMacro {
val allFieldsTermName = ctorFieldNames.map {
case v: ValDef => v.name.toTermName
}
val allTernNames = allFieldsTermName ++ getInternalFieldTermNameExcludeLocal(annotteeClassDefinitions)
val allTernNames = allFieldsTermName ++ getInternalFieldsTermNameExcludeLocal(annotteeClassDefinitions)
val hash = getHashcodeMethod(allTernNames, superClasses)
val equals = getEqualsMethod(className, allTernNames, superClasses, annotteeClassDefinitions)
c.Expr(
......
......@@ -100,7 +100,8 @@ class EqualsAndHashCodeTest extends AnyFlatSpec with Matchers {
"equals3" should "ok even if exists a canEqual" in {
@equalsAndHashCode
class Employee1(name: String, age: Int, var role: String) extends Person(name, age) {
override def canEqual(that: Any) = that.getClass == classOf[Employee1];
class A {}
override def canEqual(that: Any) = that.getClass == classOf[Employee1]
}
"""
| @equalsAndHashCode
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册