提交 cc179734 编写于 作者: 梦境迷离's avatar 梦境迷离

refactor

上级 004baeff
......@@ -176,8 +176,8 @@ abstract class AbstractMacroProcessor(val c: whitebox.Context) {
/**
* Check whether the mods of the fields has a `private[this]` or `protected[this]`, because it cannot be used out of class.
*
* @param tree a field or method
* @return
* @param tree Tree is a field or method?
* @return false if mods exists private[this] or protected[this]
*/
def isNotLocalClassMember(tree: Tree): Boolean = {
lazy val modifierNotLocal = (mods: Modifiers) => {
......@@ -246,6 +246,26 @@ abstract class AbstractMacroProcessor(val c: whitebox.Context) {
}).map(_.asInstanceOf[ValDef])
}
/**
* Extract the constructor params ValDef and flatten for currying.
*
* @param annotteeClassParams
* @return {{ Seq(ValDef) }}
*/
def getClassConstructorValDefsFlatten(annotteeClassParams: List[List[Tree]]): Seq[ValDef] = {
annotteeClassParams.flatten.map(_.asInstanceOf[ValDef])
}
/**
* Extract the constructor params ValDef not flatten.
*
* @param annotteeClassParams
* @return {{ Seq(Seq(ValDef)) }}
*/
def getClassConstructorValDefsNotFlatten(annotteeClassParams: List[List[Tree]]): Seq[Seq[ValDef]] = {
annotteeClassParams.map(_.map(_.asInstanceOf[ValDef]))
}
/**
* Extract the methods belonging to the class, contains Secondary Constructor.
*
......@@ -268,9 +288,8 @@ abstract class AbstractMacroProcessor(val c: whitebox.Context) {
* @example {{ new TestClass12(i)(j)(k)(t) }}
*/
def getConstructorWithCurrying(typeName: TypeName, fieldss: List[List[Tree]], isCase: Boolean): Tree = {
val allFieldsTermName = fieldss.map(f => f.map {
case v: ValDef => v.name.toTermName
})
val fieldssValDefNotFlatten = getClassConstructorValDefsNotFlatten(fieldss)
val allFieldsTermName = fieldssValDefNotFlatten.map(_.map(_.name.toTermName))
// not currying
val constructor = if (fieldss.isEmpty || fieldss.size == 1) {
q"${if (isCase) q"${typeName.toTermName}(..${allFieldsTermName.flatten})" else q"new $typeName(..${allFieldsTermName.flatten})"}"
......@@ -280,7 +299,6 @@ abstract class AbstractMacroProcessor(val c: whitebox.Context) {
if (isCase) q"${typeName.toTermName}(...$first)(...${allFieldsTermName.tail})"
else q"new $typeName(..$first)(...${allFieldsTermName.tail})"
}
c.info(c.enclosingPosition, s"getConstructorWithCurrying constructor: $constructor, paramss: $fieldss", force = true)
constructor
}
......@@ -303,7 +321,6 @@ abstract class AbstractMacroProcessor(val c: whitebox.Context) {
val first = allFieldsTermName.head
q"def apply[..$classTypeParams](..$first)(...${allFieldsTermName.tail}): $typeName[..$returnTypeParams] = ${getConstructorWithCurrying(typeName, fieldss, isCase = false)}"
}
c.info(c.enclosingPosition, s"getApplyMethodWithCurrying constructor: $applyMethod, paramss: $fieldss", force = true)
applyMethod
}
......
......@@ -82,14 +82,14 @@ object builderMacro {
}
override def modifiedDeclaration(classDecl: ClassDef, compDeclOpt: Option[ModuleDef] = None): Any = {
val (className, fieldss, classTypeParams) = classDecl match {
val (className, annotteeClassParams, classTypeParams) = classDecl match {
// @see https://scala-lang.org/files/archive/spec/2.13/05-classes-and-objects.html
case q"$mods class $tpname[..$tparams](...$paramss) extends ..$bases { ..$body }" =>
(tpname.asInstanceOf[TypeName], paramss.asInstanceOf[List[List[Tree]]], tparams.asInstanceOf[List[Tree]])
case _ => c.abort(c.enclosingPosition, s"${ErrorMessage.ONLY_CLASS} classDef: $classDecl")
}
val builder = getBuilderClassAndMethod(className, fieldss, classTypeParams, isCaseClass(classDecl))
val builder = getBuilderClassAndMethod(className, annotteeClassParams, classTypeParams, isCaseClass(classDecl))
val compDecl = modifiedCompanion(compDeclOpt, builder, className)
// Return both the class and companion object declarations
c.Expr(
......
......@@ -54,14 +54,13 @@ object constructorMacro {
* Extract the internal fields of members belonging to the class, but not in primary constructor and only `var`.
*/
private def getMemberVarDefTermNameWithType(annotteeClassDefinitions: Seq[Tree]): Seq[Tree] = {
getMutableValDefAndExcludeFields(annotteeClassDefinitions).map {
case v: ValDef if v.mods.hasFlag(Flag.MUTABLE) =>
if (v.tpt.isEmpty) { // val i = 1, tpt is `<type ?>`
// TODO getClass RETURN a java type, maybe we can try use class reflect to get the fields type name.
q"${v.name}: ${TypeName(toScalaType(evalTree(v.rhs).getClass.getTypeName))}"
} else {
q"${v.name}: ${v.tpt}"
}
getMutableValDefAndExcludeFields(annotteeClassDefinitions).map { v =>
if (v.tpt.isEmpty) { // val i = 1, tpt is `<type ?>`
// TODO getClass RETURN a java type, maybe we can try use class reflect to get the fields type name.
q"${v.name}: ${TypeName(toScalaType(evalTree(v.rhs).getClass.getTypeName))}"
} else {
q"${v.name}: ${v.tpt}"
}
}
}
......@@ -75,15 +74,8 @@ object constructorMacro {
c.abort(c.enclosingPosition, s"${ErrorMessage.ONLY_CLASS} and the internal fields (declare as 'var') should not be Empty.")
}
// Extract the internal fields of members belonging to the class, but not in primary constructor.
val annotteeClassFieldNames = getMutableValDefAndExcludeFields(annotteeClassDefinitions).map {
case v: ValDef if v.mods.hasFlag(Flag.MUTABLE) => v.name
}
// Extract the field of the primary constructor.
val allFieldsTermName = annotteeClassParams.map(f => f.map {
case v: ValDef => v.name.toTermName
})
val annotteeClassFieldNames = getMutableValDefAndExcludeFields(annotteeClassDefinitions).map(_.name)
val allFieldsTermName = getClassConstructorValDefsNotFlatten(annotteeClassParams).map(_.map(_.name.toTermName))
// Extract the field of the primary constructor.
val classParamsNameWithType = getConstructorParamsNameWithType(annotteeClassParams.flatten)
val applyMethod = if (annotteeClassParams.isEmpty || annotteeClassParams.size == 1) {
......
......@@ -127,10 +127,7 @@ object equalsAndHashCodeMacro {
(tpname.asInstanceOf[TypeName], paramss.asInstanceOf[List[List[Tree]]], stats.asInstanceOf[Seq[Tree]], parents.asInstanceOf[Seq[Tree]])
case _ => c.abort(c.enclosingPosition, s"${ErrorMessage.ONLY_CLASS} classDef: $classDecl")
}
val ctorFieldNames = annotteeClassParams.flatten.filter(cf => isNotLocalClassMember(cf))
val allFieldsTermName = ctorFieldNames.map {
case v: ValDef => v.name.toTermName
}
val allFieldsTermName = getClassConstructorValDefsFlatten(annotteeClassParams).filter(cf => isNotLocalClassMember(cf)).map(_.name.toTermName)
val allTernNames = allFieldsTermName ++ getInternalFieldsTermNameExcludeLocal(annotteeClassDefinitions)
val hash = getHashcodeMethod(allTernNames, superClasses)
val equals = getEqualsMethod(className, allTernNames, superClasses, annotteeClassDefinitions)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册