AST Manipulation and Analysis
The Abstract Syntax Tree (AST) is a tree representation of the syntactic structure of source code. In Go, the standard library provides powerful packages for parsing, analyzing, and manipulating Go code at the AST level, enabling sophisticated static analysis and code generation techniques.
Understanding Go’s AST Packages
Go’s standard library includes several packages for working with ASTs:
- go/parser: Parses Go source code into an AST
- go/ast: Defines the AST types and provides utilities for traversing and manipulating the tree
- go/token: Defines tokens and positions for source code representation
- go/printer: Converts an AST back to formatted Go source code
- go/types: Provides type checking and semantic analysis
Together, these packages form a comprehensive toolkit for analyzing and transforming Go code.
Parsing Go Code into an AST
The first step in AST manipulation is parsing Go source code into an AST:
package main
import (
"fmt"
"go/ast"
"go/parser"
"go/token"
)
func main() {
// Source code to parse
src := `
package example
import "fmt"
// Greeter provides greeting functionality
type Greeter struct {
Name string
}
// Greet returns a greeting message
func (g *Greeter) Greet() string {
return fmt.Sprintf("Hello, %s!", g.Name)
}
`
// Create a file set for position information
fset := token.NewFileSet()
// Parse the source code
file, err := parser.ParseFile(fset, "example.go", src, parser.ParseComments)
if err != nil {
fmt.Printf("Error parsing source: %v\n", err)
return
}
// Print the AST structure
fmt.Println("Package name:", file.Name)
// Print imports
fmt.Println("\nImports:")
for _, imp := range file.Imports {
fmt.Printf(" %s\n", imp.Path.Value)
}
// Print type declarations
fmt.Println("\nType declarations:")
for _, decl := range file.Decls {
if genDecl, ok := decl.(*ast.GenDecl); ok && genDecl.Tok == token.TYPE {
for _, spec := range genDecl.Specs {
if typeSpec, ok := spec.(*ast.TypeSpec); ok {
fmt.Printf(" %s\n", typeSpec.Name)
}
}
}
}
// Print function declarations
fmt.Println("\nFunction declarations:")
for _, decl := range file.Decls {
if funcDecl, ok := decl.(*ast.FuncDecl); ok {
if funcDecl.Recv != nil {
// Method
if len(funcDecl.Recv.List) > 0 {
if starExpr, ok := funcDecl.Recv.List[0].Type.(*ast.StarExpr); ok {
if ident, ok := starExpr.X.(*ast.Ident); ok {
fmt.Printf(" Method: (%s).%s\n", ident.Name, funcDecl.Name)
}
}
}
} else {
// Function
fmt.Printf(" Function: %s\n", funcDecl.Name)
}
}
}
}
Output:
Package name: example
Imports:
"fmt"
Type declarations:
Greeter
Function declarations:
Method: (Greeter).Greet
This example demonstrates how to parse Go code and extract basic structural information from the AST.
AST Traversal and Visitor Pattern
The ast package provides a visitor pattern for traversing the AST, built on the ast.Visitor interface and the ast.Walk function:
package main
import (
"fmt"
"go/ast"
"go/parser"
"go/token"
)
// FunctionVisitor implements the ast.Visitor interface
type FunctionVisitor struct {
Functions map[string]bool
}
// Visit is called for each node in the AST
func (v *FunctionVisitor) Visit(node ast.Node) ast.Visitor {
if node == nil {
return nil
}
// Check if the node is a function declaration
if funcDecl, ok := node.(*ast.FuncDecl); ok {
v.Functions[funcDecl.Name.Name] = true
}
return v
}
func main() {
// Source code to analyze
src := `
package example
func Add(a, b int) int {
return a + b
}
func Subtract(a, b int) int {
return a - b
}
func Multiply(a, b int) int {
return a * b
}
func Divide(a, b int) int {
if b == 0 {
panic("division by zero")
}
return a / b
}
`
// Create a file set for position information
fset := token.NewFileSet()
// Parse the source code
file, err := parser.ParseFile(fset, "example.go", src, parser.ParseComments)
if err != nil {
fmt.Printf("Error parsing source: %v\n", err)
return
}
// Create a visitor to find all functions
visitor := &FunctionVisitor{
Functions: make(map[string]bool),
}
// Walk the AST
ast.Walk(visitor, file)
// Print the functions found
fmt.Println("Functions found:")
for name := range visitor.Functions {
fmt.Printf(" - %s\n", name)
}
}
Output (Go randomizes map iteration order, so the names may appear in a different order):
Functions found:
- Add
- Subtract
- Multiply
- Divide
The visitor pattern allows for complex traversals and analyses of the AST, making it possible to extract specific information or identify patterns in the code.
Static Analysis with AST
AST analysis enables powerful static analysis tools that can identify bugs, enforce coding standards, or gather metrics:
package main
import (
"fmt"
"go/ast"
"go/parser"
"go/token"
"strings"
)
// ErrorCheckVisitor finds potential error handling issues
type ErrorCheckVisitor struct {
FileSet *token.FileSet
ErrorsFound []string
}
// Visit is called for each node in the AST
func (v *ErrorCheckVisitor) Visit(node ast.Node) ast.Visitor {
if node == nil {
return nil
}
// go/ast nodes have no Parent field, so inspect each block and look at
// every statement together with the statement that follows it
block, ok := node.(*ast.BlockStmt)
if !ok {
return v
}
for i, stmt := range block.List {
// Only consider multi-value assignments from a single expression
assign, ok := stmt.(*ast.AssignStmt)
if !ok || len(assign.Lhs) < 2 || len(assign.Rhs) != 1 {
continue
}
// The right-hand side must be a function call
if _, ok := assign.Rhs[0].(*ast.CallExpr); !ok {
continue
}
// Check if the last variable is named "err" or similar, or is discarded with _
lastVar, ok := assign.Lhs[len(assign.Lhs)-1].(*ast.Ident)
if !ok || !(strings.Contains(lastVar.Name, "err") || lastVar.Name == "_") {
continue
}
// Check if there's an if statement immediately after to check the error
if !hasErrorCheck(block, i) {
pos := v.FileSet.Position(assign.Pos())
v.ErrorsFound = append(v.ErrorsFound, fmt.Sprintf(
"Line %d: Potential unchecked error from function call",
pos.Line,
))
}
}
return v
}
// hasErrorCheck reports whether the statement after block.List[i] is an if statement
func hasErrorCheck(block *ast.BlockStmt, i int) bool {
// This is a simplified check - a real implementation would also verify
// that the if condition involves the error variable
if i+1 >= len(block.List) {
return false
}
_, ok := block.List[i+1].(*ast.IfStmt)
return ok
}
func main() {
// Source code to analyze
src := `
package example
import "os"
func processFile(filename string) {
// Unchecked error
file, err := os.Open(filename)
file.Read([]byte{})
// Properly checked error
data, err := os.ReadFile(filename)
if err != nil {
return
}
// Error ignored with _
n, _ := file.Write(data)
}
`
// Create a file set for position information
fset := token.NewFileSet()
// Parse the source code
file, err := parser.ParseFile(fset, "example.go", src, parser.ParseComments)
if err != nil {
fmt.Printf("Error parsing source: %v\n", err)
return
}
// Create a visitor to find unchecked errors
visitor := &ErrorCheckVisitor{
FileSet: fset,
ErrorsFound: make([]string, 0),
}
// Walk the AST
ast.Walk(visitor, file)
// Print the errors found
if len(visitor.ErrorsFound) > 0 {
fmt.Println("Potential error handling issues:")
for _, err := range visitor.ErrorsFound {
fmt.Printf(" %s\n", err)
}
} else {
fmt.Println("No error handling issues found")
}
}
This example demonstrates a simplified static analyzer that identifies potential unchecked errors in Go code. Real-world tools such as errcheck, staticcheck, and the now-deprecated golint use similar techniques but with more sophisticated analysis.
AST Transformation and Code Generation
Beyond analysis, the AST can be modified to transform code or generate new code:
package main
import (
"bytes"
"fmt"
"go/ast"
"go/format"
"go/parser"
"go/printer"
"go/token"
)
// AddLoggingVisitor adds logging statements to function entries
type AddLoggingVisitor struct {
FileSet *token.FileSet
Modified bool
}
// Visit is called for each node in the AST
func (v *AddLoggingVisitor) Visit(node ast.Node) ast.Visitor {
if node == nil {
return nil
}
// Check if the node is a function declaration
if funcDecl, ok := node.(*ast.FuncDecl); ok {
// Skip functions without bodies (e.g., functions implemented in assembly)
if funcDecl.Body == nil {
return v
}
// Create a logging statement
logStmt := &ast.ExprStmt{
X: &ast.CallExpr{
Fun: &ast.SelectorExpr{
X: ast.NewIdent("fmt"),
Sel: ast.NewIdent("Printf"),
},
Args: []ast.Expr{
&ast.BasicLit{
Kind: token.STRING,
Value: fmt.Sprintf(`"Entering function %s\n"`, funcDecl.Name.Name),
},
},
},
}
// Add the logging statement at the beginning of the function body
funcDecl.Body.List = append(
[]ast.Stmt{logStmt},
funcDecl.Body.List...,
)
v.Modified = true
}
return v
}
func main() {
// Source code to transform
src := `
package example
import "fmt"
func Add(a, b int) int {
return a + b
}
func Multiply(a, b int) int {
return a * b
}
`
// Create a file set for position information
fset := token.NewFileSet()
// Parse the source code
file, err := parser.ParseFile(fset, "example.go", src, parser.ParseComments)
if err != nil {
fmt.Printf("Error parsing source: %v\n", err)
return
}
// Add import for fmt if not already present
addImport(file, "fmt")
// Create a visitor to add logging
visitor := &AddLoggingVisitor{
FileSet: fset,
Modified: false,
}
// Walk the AST
ast.Walk(visitor, file)
// Print the modified code
if visitor.Modified {
var buf bytes.Buffer
printer.Fprint(&buf, fset, file)
// Format the code
formattedCode, err := format.Source(buf.Bytes())
if err != nil {
fmt.Printf("Error formatting code: %v\n", err)
return
}
fmt.Println("Modified code:")
fmt.Println(string(formattedCode))
}
}
// addImport adds an import if it's not already present
func addImport(file *ast.File, importPath string) {
// Check if the import already exists
for _, imp := range file.Imports {
if imp.Path.Value == fmt.Sprintf(`"%s"`, importPath) {
return
}
}
// Create a new import
importSpec := &ast.ImportSpec{
Path: &ast.BasicLit{
Kind: token.STRING,
Value: fmt.Sprintf(`"%s"`, importPath),
},
}
// Find the import declaration or create a new one
var importDecl *ast.GenDecl
for _, decl := range file.Decls {
if genDecl, ok := decl.(*ast.GenDecl); ok && genDecl.Tok == token.IMPORT {
importDecl = genDecl
break
}
}
if importDecl == nil {
// Create a new import declaration
importDecl = &ast.GenDecl{
Tok: token.IMPORT,
Specs: []ast.Spec{importSpec},
}
file.Decls = append([]ast.Decl{importDecl}, file.Decls...)
} else {
// Add to existing import declaration
importDecl.Specs = append(importDecl.Specs, importSpec)
}
}
Output:
Modified code:
package example
import "fmt"
func Add(a, b int) int {
fmt.Printf("Entering function Add\n")
return a + b
}
func Multiply(a, b int) int {
fmt.Printf("Entering function Multiply\n")
return a * b
}
This example demonstrates how to transform Go code by adding logging statements to all functions. Similar techniques can be used for more complex transformations like adding instrumentation, refactoring code, or implementing cross-cutting concerns.
Building a Custom Linter
AST analysis is particularly useful for building custom linters that enforce project-specific coding standards:
package main
import (
"fmt"
"go/ast"
"go/parser"
"go/token"
"strings"
)
// LintRule defines a rule for linting
type LintRule interface {
Check(file *ast.File, fset *token.FileSet) []LintIssue
}
// LintIssue represents a linting issue
type LintIssue struct {
Position token.Position
Message string
}
// ExportedDocRule checks that exported declarations have documentation
type ExportedDocRule struct{}
func (r *ExportedDocRule) Check(file *ast.File, fset *token.FileSet) []LintIssue {
issues := []LintIssue{}
// Check all declarations
for _, decl := range file.Decls {
// Check functions
if funcDecl, ok := decl.(*ast.FuncDecl); ok {
if ast.IsExported(funcDecl.Name.Name) && funcDecl.Doc == nil {
issues = append(issues, LintIssue{
Position: fset.Position(funcDecl.Pos()),
Message: fmt.Sprintf("exported function %s should have a comment", funcDecl.Name.Name),
})
}
}
// Check type declarations
if genDecl, ok := decl.(*ast.GenDecl); ok && genDecl.Tok == token.TYPE {
for _, spec := range genDecl.Specs {
if typeSpec, ok := spec.(*ast.TypeSpec); ok {
if ast.IsExported(typeSpec.Name.Name) && genDecl.Doc == nil && typeSpec.Doc == nil {
issues = append(issues, LintIssue{
Position: fset.Position(typeSpec.Pos()),
Message: fmt.Sprintf("exported type %s should have a comment", typeSpec.Name.Name),
})
}
}
}
}
}
return issues
}
// NameConventionRule checks naming conventions
type NameConventionRule struct{}
func (r *NameConventionRule) Check(file *ast.File, fset *token.FileSet) []LintIssue {
issues := []LintIssue{}
// Visit all constant declarations; go/ast nodes have no Parent field,
// so match the enclosing GenDecl and iterate over its specs
ast.Inspect(file, func(n ast.Node) bool {
genDecl, ok := n.(*ast.GenDecl)
if !ok || genDecl.Tok != token.CONST {
return true
}
for _, spec := range genDecl.Specs {
valueSpec, ok := spec.(*ast.ValueSpec)
if !ok {
continue
}
for _, name := range valueSpec.Names {
// Check if the exported constant is in all caps
if ast.IsExported(name.Name) && !isAllCaps(name.Name) {
issues = append(issues, LintIssue{
Position: fset.Position(name.Pos()),
Message: fmt.Sprintf("exported constant %s should use ALL_CAPS", name.Name),
})
}
}
}
return true
})
return issues
}
// isAllCaps checks if a string is all uppercase with underscores
func isAllCaps(s string) bool {
return strings.ToUpper(s) == s && !strings.Contains(s, " ")
}
func main() {
// Source code to lint
src := `
package example
// DocumentedType has a doc comment
type DocumentedType struct {
Field string
}
type UndocumentedType struct {
// The type itself has no doc comment
Field string
}
// DocumentedFunction has a doc comment
func DocumentedFunction() {
}
func UndocumentedFunction() {
// A comment inside the body is not a doc comment
}
const (
CORRECT_CONSTANT = "value"
IncorrectConstant = "value"
)
`
// Create a file set for position information
fset := token.NewFileSet()
// Parse the source code
file, err := parser.ParseFile(fset, "example.go", src, parser.ParseComments)
if err != nil {
fmt.Printf("Error parsing source: %v\n", err)
return
}
// Define linting rules
rules := []LintRule{
&ExportedDocRule{},
&NameConventionRule{},
}
// Apply all rules
allIssues := []LintIssue{}
for _, rule := range rules {
issues := rule.Check(file, fset)
allIssues = append(allIssues, issues...)
}
// Print the issues found
if len(allIssues) > 0 {
fmt.Printf("Found %d linting issues:\n", len(allIssues))
for _, issue := range allIssues {
fmt.Printf("%s: %s\n", issue.Position, issue.Message)
}
} else {
fmt.Println("No linting issues found")
}
}
This example demonstrates a simple linter that enforces documentation for exported declarations and naming conventions for constants. Real-world linters such as staticcheck and the now-deprecated golint use similar techniques but with more sophisticated rules and analyses.
AST-Based Code Generation
AST manipulation can be used to generate code based on existing code structures:
package main
import (
"bytes"
"fmt"
"go/ast"
"go/format"
"go/parser"
"go/printer"
"go/token"
)
// GenerateEqualsMethod adds an Equals method to struct types
func GenerateEqualsMethod(file *ast.File, fset *token.FileSet) (*ast.File, error) {
// Shallow-copy the file so that appending declarations does not modify the original's Decls slice
newFile := &ast.File{
Name: file.Name,
Decls: make([]ast.Decl, len(file.Decls)),
Scope: file.Scope,
Imports: file.Imports,
Comments: file.Comments,
}
copy(newFile.Decls, file.Decls)
// Find struct types
for _, decl := range file.Decls {
genDecl, ok := decl.(*ast.GenDecl)
if !ok || genDecl.Tok != token.TYPE {
continue
}
for _, spec := range genDecl.Specs {
typeSpec, ok := spec.(*ast.TypeSpec)
if !ok {
continue
}
structType, ok := typeSpec.Type.(*ast.StructType)
if !ok {
continue
}
// Generate Equals method for this struct
equalsMethod := generateEqualsMethodForStruct(typeSpec.Name.Name, structType)
newFile.Decls = append(newFile.Decls, equalsMethod)
}
}
return newFile, nil
}
// generateEqualsMethodForStruct creates an Equals method for a struct
func generateEqualsMethodForStruct(typeName string, structType *ast.StructType) *ast.FuncDecl {
// Create method receiver
receiver := &ast.FieldList{
List: []*ast.Field{
{
Names: []*ast.Ident{ast.NewIdent("s")},
Type: &ast.StarExpr{
X: ast.NewIdent(typeName),
},
},
},
}
// Create method parameters
params := &ast.FieldList{
List: []*ast.Field{
{
Names: []*ast.Ident{ast.NewIdent("other")},
Type: &ast.StarExpr{
X: ast.NewIdent(typeName),
},
},
},
}
// Create method return type
results := &ast.FieldList{
List: []*ast.Field{
{
Type: ast.NewIdent("bool"),
},
},
}
// Create method body
stmts := []ast.Stmt{
// if other == nil { return false }
&ast.IfStmt{
Cond: &ast.BinaryExpr{
X: ast.NewIdent("other"),
Op: token.EQL,
Y: ast.NewIdent("nil"),
},
Body: &ast.BlockStmt{
List: []ast.Stmt{
&ast.ReturnStmt{
Results: []ast.Expr{ast.NewIdent("false")},
},
},
},
},
}
// Add field comparisons
for _, field := range structType.Fields.List {
if len(field.Names) == 0 {
// Skip embedded fields for simplicity
continue
}
for _, name := range field.Names {
// if s.Field != other.Field { return false }
stmts = append(stmts, &ast.IfStmt{
Cond: &ast.BinaryExpr{
X: &ast.SelectorExpr{
X: ast.NewIdent("s"),
Sel: ast.NewIdent(name.Name),
},
Op: token.NEQ,
Y: &ast.SelectorExpr{
X: ast.NewIdent("other"),
Sel: ast.NewIdent(name.Name),
},
},
Body: &ast.BlockStmt{
List: []ast.Stmt{
&ast.ReturnStmt{
Results: []ast.Expr{ast.NewIdent("false")},
},
},
},
})
}
}
// Add final return true
stmts = append(stmts, &ast.ReturnStmt{
Results: []ast.Expr{ast.NewIdent("true")},
})
// Create the method
return &ast.FuncDecl{
Recv: receiver,
Name: ast.NewIdent("Equals"),
Type: &ast.FuncType{
Params: params,
Results: results,
},
Body: &ast.BlockStmt{
List: stmts,
},
}
}
func main() {
// Source code to transform
src := `
package example
// Person represents a person
type Person struct {
Name string
Age int
Address string
}
// Product represents a product
type Product struct {
ID string
Price float64
}
`
// Create a file set for position information
fset := token.NewFileSet()
// Parse the source code
file, err := parser.ParseFile(fset, "example.go", src, parser.ParseComments)
if err != nil {
fmt.Printf("Error parsing source: %v\n", err)
return
}
// Generate Equals methods
newFile, err := GenerateEqualsMethod(file, fset)
if err != nil {
fmt.Printf("Error generating methods: %v\n", err)
return
}
// Print the modified code
var buf bytes.Buffer
printer.Fprint(&buf, fset, newFile)
// Format the code
formattedCode, err := format.Source(buf.Bytes())
if err != nil {
fmt.Printf("Error formatting code: %v\n", err)
return
}
fmt.Println("Generated code:")
fmt.Println(string(formattedCode))
}
Output:
Generated code:
package example
// Person represents a person
type Person struct {
Name string
Age int
Address string
}
// Product represents a product
type Product struct {
ID string
Price float64
}
func (s *Person) Equals(other *Person) bool {
if other == nil {
return false
}
if s.Name != other.Name {
return false
}
if s.Age != other.Age {
return false
}
if s.Address != other.Address {
return false
}
return true
}
func (s *Product) Equals(other *Product) bool {
if other == nil {
return false
}
if s.ID != other.ID {
return false
}
if s.Price != other.Price {
return false
}
return true
}
This example demonstrates how to generate equality methods for struct types by analyzing the AST and creating new method declarations. Similar techniques can be used to generate other common methods like serialization, validation, or builder patterns.
AST manipulation provides a powerful way to analyze and transform Go code programmatically. In the next section, we’ll explore how to build complete code generation tools that combine these techniques into reusable packages.