AST Manipulation and Analysis

The Abstract Syntax Tree (AST) is a tree representation of the syntactic structure of source code. In Go, the standard library provides powerful packages for parsing, analyzing, and manipulating Go code at the AST level, enabling sophisticated static analysis and code generation techniques.

Understanding Go’s AST Packages

Go’s standard library includes several packages for working with ASTs:

  1. go/parser: Parses Go source code into an AST
  2. go/ast: Defines the AST types and provides utilities for traversing and manipulating the tree
  3. go/token: Defines tokens and positions for source code representation
  4. go/printer: Converts an AST back to formatted Go source code
  5. go/types: Provides type checking and semantic analysis

Together, these packages form a comprehensive toolkit for analyzing and transforming Go code.

Parsing Go Code into an AST

The first step in AST manipulation is parsing Go source code into an AST:

package main

import (
	"fmt"
	"go/ast"
	"go/parser"
	"go/token"
)

func main() {
	// Source code to parse
	src := `
package example

import "fmt"

// Greeter provides greeting functionality
type Greeter struct {
	Name string
}

// Greet returns a greeting message
func (g *Greeter) Greet() string {
	return fmt.Sprintf("Hello, %s!", g.Name)
}
`

	// Create a file set for position information
	fset := token.NewFileSet()

	// Parse the source code; ParseComments keeps comments in the AST.
	file, err := parser.ParseFile(fset, "example.go", src, parser.ParseComments)
	if err != nil {
		fmt.Printf("Error parsing source: %v\n", err)
		return
	}

	// file.Name is an *ast.Ident; printing it yields the package name.
	fmt.Println("Package name:", file.Name)

	// Print imports (Path.Value is the literal as written, quotes included).
	fmt.Println("\nImports:")
	for _, imp := range file.Imports {
		fmt.Printf("  %s\n", imp.Path.Value)
	}

	// Print type declarations; one `type` GenDecl may hold several TypeSpecs.
	fmt.Println("\nType declarations:")
	for _, decl := range file.Decls {
		genDecl, ok := decl.(*ast.GenDecl)
		if !ok || genDecl.Tok != token.TYPE {
			continue
		}
		for _, spec := range genDecl.Specs {
			if typeSpec, ok := spec.(*ast.TypeSpec); ok {
				fmt.Printf("  %s\n", typeSpec.Name)
			}
		}
	}

	// Print function and method declarations.
	fmt.Println("\nFunction declarations:")
	for _, decl := range file.Decls {
		funcDecl, ok := decl.(*ast.FuncDecl)
		if !ok {
			continue
		}
		if funcDecl.Recv == nil || len(funcDecl.Recv.List) == 0 {
			fmt.Printf("  Function: %s\n", funcDecl.Name)
			continue
		}
		// Handles both pointer (*T) and value (T) receivers.
		recvName, ok := receiverTypeName(funcDecl.Recv.List[0].Type)
		if !ok {
			recvName = "?"
		}
		fmt.Printf("  Method: (%s).%s\n", recvName, funcDecl.Name)
	}
}

// receiverTypeName returns the base type name of a method receiver type
// expression, unwrapping pointers (*T), parentheses ((T)) and a single type
// argument (T[P]). It reports false for any other form, including receivers
// with several type parameters (*ast.IndexListExpr).
func receiverTypeName(expr ast.Expr) (string, bool) {
	for {
		switch e := expr.(type) {
		case *ast.Ident:
			return e.Name, true
		case *ast.StarExpr:
			expr = e.X
		case *ast.ParenExpr:
			expr = e.X
		case *ast.IndexExpr:
			expr = e.X
		default:
			return "", false
		}
	}
}

Output:

Package name: example

Imports:
  "fmt"

Type declarations:
  Greeter

Function declarations:
  Method: (Greeter).Greet

This example demonstrates how to parse Go code and extract basic structural information from the AST.

AST Traversal and Visitor Pattern

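The go/ast package supports the visitor pattern through the ast.Visitor interface. ast.Walk calls the visitor's Visit method on every node in depth-first order, and ast.Inspect offers a lighter closure-based alternative: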

package main

import (
	"fmt"
	"go/ast"
	"go/parser"
	"go/token"
)

// FunctionVisitor implements the ast.Visitor interface and records the name
// of every function and method declaration it encounters.
type FunctionVisitor struct {
	// Functions is the set of declared names. Methods with the same name on
	// different receivers share one entry.
	Functions map[string]bool
	// Order lists the names in Functions in the order they were first seen.
	// Ranging over the Functions map would give a random order instead.
	Order []string
}

// Visit is called for each node in the AST. It records *ast.FuncDecl names
// and always continues into children. A nil Functions map is allocated on
// first use, so a zero FunctionVisitor is ready to use.
func (v *FunctionVisitor) Visit(node ast.Node) ast.Visitor {
	if node == nil {
		return nil
	}

	if funcDecl, ok := node.(*ast.FuncDecl); ok {
		if v.Functions == nil {
			v.Functions = make(map[string]bool)
		}
		name := funcDecl.Name.Name
		if !v.Functions[name] {
			v.Order = append(v.Order, name)
		}
		v.Functions[name] = true
	}

	return v
}

func main() {
	// Source code to analyze
	src := `
package example

func Add(a, b int) int {
	return a + b
}

func Subtract(a, b int) int {
	return a - b
}

func Multiply(a, b int) int {
	return a * b
}

func Divide(a, b int) int {
	if b == 0 {
		panic("division by zero")
	}
	return a / b
}
`

	// Create a file set for position information
	fset := token.NewFileSet()

	// Parse the source code
	file, err := parser.ParseFile(fset, "example.go", src, parser.ParseComments)
	if err != nil {
		fmt.Printf("Error parsing source: %v\n", err)
		return
	}

	// Create a visitor to find all functions
	visitor := &FunctionVisitor{
		Functions: make(map[string]bool),
	}

	// Walk the AST depth-first; declarations are visited in source order.
	ast.Walk(visitor, file)

	// Print the functions in declaration order (deterministic output).
	fmt.Println("Functions found:")
	for _, name := range visitor.Order {
		fmt.Printf("  - %s\n", name)
	}
}

Output:

Functions found:
  - Add
  - Subtract
  - Multiply
  - Divide

The visitor pattern allows for complex traversals and analyses of the AST, making it possible to extract specific information or identify patterns in the code.

Static Analysis with AST

AST analysis enables powerful static analysis tools that can identify bugs, enforce coding standards, or gather metrics:

package main

import (
	"fmt"
	"go/ast"
	"go/parser"
	"go/token"
	"strings"
)

// ErrorCheckVisitor finds potential error handling issues.
//
// It flags an assignment that matches all of the following:
//   - two or more left-hand values and a single function call on the right;
//   - a last left-hand identifier that contains "err" or is the blank
//     identifier;
//   - no if statement immediately after it whose condition mentions that
//     identifier.
//
// Results assigned to _ are always flagged. This is a name-based heuristic
// without type information.
type ErrorCheckVisitor struct {
	FileSet     *token.FileSet
	ErrorsFound []string
}

// Visit is called for each node in the AST.
//
// go/ast nodes have no parent pointers. The check therefore runs on the nodes
// that own a statement list (blocks and switch/select clauses), where an
// assignment and the statement after it are both visible.
func (v *ErrorCheckVisitor) Visit(node ast.Node) ast.Visitor {
	switch n := node.(type) {
	case nil:
		return nil
	case *ast.BlockStmt:
		v.checkStmts(n.List)
	case *ast.CaseClause:
		v.checkStmts(n.Body)
	case *ast.CommClause:
		v.checkStmts(n.Body)
	}
	return v
}

// checkStmts appends an entry to ErrorsFound for every error-returning
// assignment in stmts that is not immediately followed by a check of the
// error variable.
func (v *ErrorCheckVisitor) checkStmts(stmts []ast.Stmt) {
	for i, stmt := range stmts {
		assign, ok := stmt.(*ast.AssignStmt)
		if !ok || len(assign.Lhs) < 2 || len(assign.Rhs) != 1 {
			continue
		}
		if _, isCall := assign.Rhs[0].(*ast.CallExpr); !isCall {
			continue
		}
		errVar, ok := assign.Lhs[len(assign.Lhs)-1].(*ast.Ident)
		if !ok || (errVar.Name != "_" && !strings.Contains(errVar.Name, "err")) {
			continue
		}

		// next stays nil when the assignment is the last statement.
		var next ast.Stmt
		if i+1 < len(stmts) {
			next = stmts[i+1]
		}
		if !v.hasErrorCheck(next, errVar.Name) {
			pos := v.FileSet.Position(assign.Pos())
			v.ErrorsFound = append(v.ErrorsFound, fmt.Sprintf(
				"Line %d: Potential unchecked error from function call",
				pos.Line,
			))
		}
	}
}

// hasErrorCheck reports whether next is an if statement whose condition
// refers to the identifier name. It returns false when next is nil, when it
// is not an if statement, and for the blank identifier (which cannot be read).
func (v *ErrorCheckVisitor) hasErrorCheck(next ast.Stmt, name string) bool {
	if name == "_" {
		return false
	}
	ifStmt, ok := next.(*ast.IfStmt)
	if !ok || ifStmt.Cond == nil {
		return false
	}
	found := false
	ast.Inspect(ifStmt.Cond, func(n ast.Node) bool {
		if id, ok := n.(*ast.Ident); ok && id.Name == name {
			found = true
		}
		return !found
	})
	return found
}

func main() {
	// The snippet mixes an unchecked error, a checked error and a discarded one.
	src := `
package example

import "os"

func processFile(filename string) {
	// Unchecked error
	file, err := os.Open(filename)
	file.Read([]byte{})
	
	// Properly checked error
	data, err := os.ReadFile(filename)
	if err != nil {
		return
	}
	
	// Error ignored with _
	n, _ := file.Write(data)
}
`

	// Positions in the report are resolved through this FileSet.
	fset := token.NewFileSet()
	parsed, parseErr := parser.ParseFile(fset, "example.go", src, parser.ParseComments)
	if parseErr != nil {
		fmt.Printf("Error parsing source: %v\n", parseErr)
		return
	}

	checker := &ErrorCheckVisitor{FileSet: fset, ErrorsFound: make([]string, 0)}
	ast.Walk(checker, parsed)

	if len(checker.ErrorsFound) == 0 {
		fmt.Println("No error handling issues found")
		return
	}

	fmt.Println("Potential error handling issues:")
	for _, finding := range checker.ErrorsFound {
		fmt.Printf("  %s\n", finding)
	}
}

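This example demonstrates a simplified static analyzer that identifies potential unchecked errors in Go code. go/ast nodes carry no parent pointers, so a check that depends on neighboring statements must start from the node that owns the statement list, such as an *ast.BlockStmt. Real-world analyzers such as errcheck and staticcheck (and the now-deprecated golint) use similar traversal techniques. They combine them with type information from go/types, so they know which calls actually return an error instead of guessing from variable names.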

AST Transformation and Code Generation

Beyond analysis, the AST can be modified to transform code or generate new code:

package main

import (
	"bytes"
	"fmt"
	"go/ast"
	"go/format"
	"go/parser"
	"go/printer"
	"go/token"
)

// AddLoggingVisitor is an ast.Visitor that inserts
//
//	fmt.Printf("Entering function <name>\n")
//
// as the first statement of every function declaration that has a body. It
// does not add the fmt import itself.
type AddLoggingVisitor struct {
	// FileSet is kept for callers; Visit does not read it.
	FileSet *token.FileSet
	// Modified becomes true once at least one function has been rewritten.
	Modified bool
}

// Visit prepends the logging call to the body of each *ast.FuncDecl it sees.
// Declarations whose Body is nil are left untouched. Traversal always
// continues into children.
func (v *AddLoggingVisitor) Visit(node ast.Node) ast.Visitor {
	if node == nil {
		return nil
	}
	fn, isFunc := node.(*ast.FuncDecl)
	if !isFunc || fn.Body == nil {
		return v
	}

	// The string literal's Value is Go source text, so it includes the quotes
	// and a literal backslash-n escape.
	message := &ast.BasicLit{
		Kind:  token.STRING,
		Value: fmt.Sprintf(`"Entering function %s\n"`, fn.Name.Name),
	}
	printfCall := &ast.CallExpr{
		Fun:  &ast.SelectorExpr{X: ast.NewIdent("fmt"), Sel: ast.NewIdent("Printf")},
		Args: []ast.Expr{message},
	}

	// Build a fresh statement list: the log call first, then the original body.
	body := make([]ast.Stmt, 0, len(fn.Body.List)+1)
	body = append(body, &ast.ExprStmt{X: printfCall})
	fn.Body.List = append(body, fn.Body.List...)

	v.Modified = true
	return v
}

func main() {
	// Source code to transform
	src := `
package example

import "fmt"

func Add(a, b int) int {
	return a + b
}

func Multiply(a, b int) int {
	return a * b
}
`

	// Create a file set for position information
	fset := token.NewFileSet()

	// Parse the source code
	file, err := parser.ParseFile(fset, "example.go", src, parser.ParseComments)
	if err != nil {
		fmt.Printf("Error parsing source: %v\n", err)
		return
	}

	// The injected statements call fmt.Printf, so make sure fmt is imported.
	addImport(file, "fmt")

	// Create a visitor to add logging
	visitor := &AddLoggingVisitor{
		FileSet:  fset,
		Modified: false,
	}

	// Walk the AST
	ast.Walk(visitor, file)

	if !visitor.Modified {
		return
	}

	// Print the modified AST back to source. Report printer failures instead
	// of formatting a possibly partial buffer.
	var buf bytes.Buffer
	if err := printer.Fprint(&buf, fset, file); err != nil {
		fmt.Printf("Error printing code: %v\n", err)
		return
	}

	// Format the code
	formattedCode, err := format.Source(buf.Bytes())
	if err != nil {
		fmt.Printf("Error formatting code: %v\n", err)
		return
	}

	fmt.Println("Modified code:")
	fmt.Println(string(formattedCode))
}

// addImport adds an import of importPath to file unless one is already
// present.
//
// Existing imports are matched on the text between the quotes, so both the
// "path" and `path` spellings count as present; escape sequences are not
// decoded. The new spec is appended to the file's first import declaration.
// If the file has none, a new import declaration is placed at the front of
// file.Decls. file.Imports is updated too, so later calls see the addition.
func addImport(file *ast.File, importPath string) {
	// Check if the import already exists.
	for _, imp := range file.Imports {
		lit := imp.Path.Value
		if len(lit) >= 2 && lit[1:len(lit)-1] == importPath {
			return
		}
	}

	// Create a new import spec; %q yields a valid Go string literal.
	importSpec := &ast.ImportSpec{
		Path: &ast.BasicLit{
			Kind:  token.STRING,
			Value: fmt.Sprintf("%q", importPath),
		},
	}

	// Find the first import declaration, if any.
	var importDecl *ast.GenDecl
	for _, decl := range file.Decls {
		if genDecl, ok := decl.(*ast.GenDecl); ok && genDecl.Tok == token.IMPORT {
			importDecl = genDecl
			break
		}
	}

	if importDecl == nil {
		// Create a new import declaration ahead of all other declarations.
		importDecl = &ast.GenDecl{
			Tok:   token.IMPORT,
			Specs: []ast.Spec{importSpec},
		}
		file.Decls = append([]ast.Decl{importDecl}, file.Decls...)
	} else {
		importDecl.Specs = append(importDecl.Specs, importSpec)
		// A former single `import "x"` now holds several specs. Give it a valid
		// Lparen so the printer emits a parenthesized import group.
		if len(importDecl.Specs) > 1 && !importDecl.Lparen.IsValid() {
			importDecl.Lparen = importDecl.Pos()
		}
	}

	// Keep the file's import index consistent with its declarations.
	file.Imports = append(file.Imports, importSpec)
}

Output:

Modified code:
package example

import "fmt"

func Add(a, b int) int {
	fmt.Printf("Entering function Add\n")
	return a + b
}

func Multiply(a, b int) int {
	fmt.Printf("Entering function Multiply\n")
	return a * b
}

This example demonstrates how to transform Go code by adding logging statements to all functions. Similar techniques can be used for more complex transformations like adding instrumentation, refactoring code, or implementing cross-cutting concerns.

Building a Custom Linter

AST analysis is particularly useful for building custom linters that enforce project-specific coding standards:

package main

import (
	"fmt"
	"go/ast"
	"go/parser"
	"go/token"
	"strings"
)

// LintRule defines a rule for linting.
//
// Check inspects one parsed file and returns the issues it found (possibly
// none). The rules in this file use fset to turn node positions into the
// file:line:column stored in LintIssue.Position.
type LintRule interface {
	Check(file *ast.File, fset *token.FileSet) []LintIssue
}

// LintIssue represents a linting issue.
type LintIssue struct {
	// Position is the resolved source location the issue refers to.
	Position token.Position
	// Message is a human-readable description of the problem.
	Message  string
}

// ExportedDocRule checks that exported functions, methods and types have a
// doc comment.
type ExportedDocRule struct{}

// Check reports two kinds of issue:
//   - an exported top-level function or method declaration without a doc
//     comment;
//   - an exported type spec that has neither its own doc comment nor one on
//     its enclosing type declaration.
func (r *ExportedDocRule) Check(file *ast.File, fset *token.FileSet) []LintIssue {
	var issues []LintIssue

	for _, decl := range file.Decls {
		switch d := decl.(type) {
		case *ast.FuncDecl:
			if ast.IsExported(d.Name.Name) && d.Doc == nil {
				issues = append(issues, LintIssue{
					Position: fset.Position(d.Pos()),
					Message:  fmt.Sprintf("exported function %s should have a comment", d.Name.Name),
				})
			}

		case *ast.GenDecl:
			if d.Tok != token.TYPE {
				continue
			}
			for _, spec := range d.Specs {
				typeSpec, ok := spec.(*ast.TypeSpec)
				if !ok || !ast.IsExported(typeSpec.Name.Name) {
					continue
				}
				// In a grouped `type ( ... )` declaration each spec carries its
				// own comment in TypeSpec.Doc, while GenDecl.Doc covers the whole
				// group. Either one counts as documentation.
				if typeSpec.Doc == nil && d.Doc == nil {
					issues = append(issues, LintIssue{
						Position: fset.Position(typeSpec.Pos()),
						Message:  fmt.Sprintf("exported type %s should have a comment", typeSpec.Name.Name),
					})
				}
			}
		}
	}

	return issues
}

// NameConventionRule checks naming conventions for exported constants.
//
// It enforces a project-specific ALL_CAPS style for exported constants. This
// is the opposite of standard Go style, which uses MixedCaps for constants
// just as for every other identifier.
type NameConventionRule struct{}

// Check reports every exported constant whose name is not ALL_CAPS according
// to isAllCaps. It covers package-level constants and constants declared
// inside functions.
func (r *NameConventionRule) Check(file *ast.File, fset *token.FileSet) []LintIssue {
	var issues []LintIssue

	// go/ast nodes have no parent links. Find the const declarations first,
	// then examine the value specs they contain.
	ast.Inspect(file, func(n ast.Node) bool {
		genDecl, ok := n.(*ast.GenDecl)
		if !ok || genDecl.Tok != token.CONST {
			return true
		}
		for _, spec := range genDecl.Specs {
			valueSpec, ok := spec.(*ast.ValueSpec)
			if !ok {
				continue
			}
			for _, name := range valueSpec.Names {
				if ast.IsExported(name.Name) && !isAllCaps(name.Name) {
					issues = append(issues, LintIssue{
						Position: fset.Position(name.Pos()),
						Message:  fmt.Sprintf("exported constant %s should use ALL_CAPS", name.Name),
					})
				}
			}
		}
		return true
	})

	return issues
}

// isAllCaps reports whether s is unchanged by strings.ToUpper and contains no
// spaces. Digits and underscores are therefore allowed and lowercase letters
// are not.
func isAllCaps(s string) bool {
	return strings.ToUpper(s) == s && !strings.Contains(s, " ")
}

func main() {
	// Source code to lint. Each name states what the rules should conclude:
	// - a doc comment must sit directly above its declaration; a comment
	//   inside a body or on a struct field does not count;
	// - only exported constants are subject to the naming rule.
	src := `
package example

// DocumentedType is a documented type
type DocumentedType struct {
	Field string
}

type UndocumentedType struct {
	// This comment documents the field, not the type
	Field string
}

// DocumentedFunction is a documented function
func DocumentedFunction() {
}

func UndocumentedFunction() {
	// A comment inside the body is not documentation
}

const (
	CORRECT_CONSTANT  = "value"
	IncorrectConstant = "value"
)
`

	// Create a file set for position information
	fset := token.NewFileSet()

	// Parse the source code; ParseComments is required for doc comments.
	file, err := parser.ParseFile(fset, "example.go", src, parser.ParseComments)
	if err != nil {
		fmt.Printf("Error parsing source: %v\n", err)
		return
	}

	// Define linting rules
	rules := []LintRule{
		&ExportedDocRule{},
		&NameConventionRule{},
	}

	// Apply all rules
	var allIssues []LintIssue
	for _, rule := range rules {
		allIssues = append(allIssues, rule.Check(file, fset)...)
	}

	if len(allIssues) == 0 {
		fmt.Println("No linting issues found")
		return
	}

	fmt.Printf("Found %d linting issues:\n", len(allIssues))
	for _, issue := range allIssues {
		fmt.Printf("%s: %s\n", issue.Position, issue.Message)
	}
}

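This example demonstrates a simple linter that enforces documentation for exported declarations and a naming convention for constants. The ALL_CAPS rule is purely a project-specific illustration: idiomatic Go names constants in MixedCaps, like every other identifier. Real-world linters such as staticcheck (and the now-deprecated golint) use similar techniques but with more sophisticated rules and analyses.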

AST-Based Code Generation

AST manipulation can be used to generate code based on existing code structures:

package main

import (
	"bytes"
	"fmt"
	"go/ast"
	"go/format"
	"go/parser"
	"go/printer"
	"go/token"
	"strings"
)

// GenerateEqualsMethod returns a copy of file with an Equals method appended
// for each package-level struct type.
//
// The copy is shallow. It has its own Decls slice, so file.Decls is never
// changed, but all other nodes (including the existing declarations) are
// shared with file. Struct types that already declare a method named Equals
// are skipped, so the result never contains a duplicate method. fset is not
// used at present, and the returned error is always nil.
func GenerateEqualsMethod(file *ast.File, fset *token.FileSet) (*ast.File, error) {
	// Copy every File field (package clause position, doc comment, imports,
	// comments, ...) so the printer lays out the copy like the original. Then
	// give the copy its own Decls slice to append to.
	newFile := *file
	newFile.Decls = append([]ast.Decl(nil), file.Decls...)

	existing := typesWithEqualsMethod(file)

	// Find struct types
	for _, decl := range file.Decls {
		genDecl, ok := decl.(*ast.GenDecl)
		if !ok || genDecl.Tok != token.TYPE {
			continue
		}

		for _, spec := range genDecl.Specs {
			typeSpec, ok := spec.(*ast.TypeSpec)
			if !ok || existing[typeSpec.Name.Name] {
				continue
			}

			structType, ok := typeSpec.Type.(*ast.StructType)
			if !ok {
				continue
			}

			equalsMethod := generateEqualsMethodForStruct(typeSpec.Name.Name, structType)
			newFile.Decls = append(newFile.Decls, equalsMethod)
		}
	}

	return &newFile, nil
}

// typesWithEqualsMethod returns the names of types in file that already have
// a method named Equals with a T or *T receiver.
func typesWithEqualsMethod(file *ast.File) map[string]bool {
	names := make(map[string]bool)
	for _, decl := range file.Decls {
		fn, ok := decl.(*ast.FuncDecl)
		if !ok || fn.Recv == nil || len(fn.Recv.List) == 0 || fn.Name.Name != "Equals" {
			continue
		}
		recv := fn.Recv.List[0].Type
		if star, ok := recv.(*ast.StarExpr); ok {
			recv = star.X
		}
		if ident, ok := recv.(*ast.Ident); ok {
			names[ident.Name] = true
		}
	}
	return names
}

// generateEqualsMethodForStruct builds
//
//	func (s *typeName) Equals(other *typeName) bool
//
// The method returns false when other is nil or when any named field differs
// (compared with !=), and true otherwise.
//
// Limitations:
//   - Embedded fields are not compared.
//   - Blank (_) fields are skipped because they cannot be referenced.
//   - Fields of non-comparable types (slices, maps, funcs) produce code that
//     does not compile.
//   - The body dereferences s for each compared field, so a nil receiver
//     panics when other is non-nil.
func generateEqualsMethodForStruct(typeName string, structType *ast.StructType) *ast.FuncDecl {
	// Fresh nodes are built for every use so no AST node is shared between
	// two places in the tree.
	ptrToType := func() ast.Expr {
		return &ast.StarExpr{X: ast.NewIdent(typeName)}
	}
	returnBool := func(value string) *ast.BlockStmt {
		return &ast.BlockStmt{
			List: []ast.Stmt{
				&ast.ReturnStmt{Results: []ast.Expr{ast.NewIdent(value)}},
			},
		}
	}

	// if other == nil { return false }
	stmts := []ast.Stmt{
		&ast.IfStmt{
			Cond: &ast.BinaryExpr{
				X:  ast.NewIdent("other"),
				Op: token.EQL,
				Y:  ast.NewIdent("nil"),
			},
			Body: returnBool("false"),
		},
	}

	// Add field comparisons; embedded fields have no Names and are skipped.
	if structType.Fields != nil {
		for _, field := range structType.Fields.List {
			for _, name := range field.Names {
				if name.Name == "_" {
					continue
				}
				// if s.Field != other.Field { return false }
				stmts = append(stmts, &ast.IfStmt{
					Cond: &ast.BinaryExpr{
						X:  &ast.SelectorExpr{X: ast.NewIdent("s"), Sel: ast.NewIdent(name.Name)},
						Op: token.NEQ,
						Y:  &ast.SelectorExpr{X: ast.NewIdent("other"), Sel: ast.NewIdent(name.Name)},
					},
					Body: returnBool("false"),
				})
			}
		}
	}

	// return true
	stmts = append(stmts, &ast.ReturnStmt{
		Results: []ast.Expr{ast.NewIdent("true")},
	})

	return &ast.FuncDecl{
		Recv: &ast.FieldList{
			List: []*ast.Field{{Names: []*ast.Ident{ast.NewIdent("s")}, Type: ptrToType()}},
		},
		Name: ast.NewIdent("Equals"),
		Type: &ast.FuncType{
			Params: &ast.FieldList{
				List: []*ast.Field{{Names: []*ast.Ident{ast.NewIdent("other")}, Type: ptrToType()}},
			},
			Results: &ast.FieldList{
				List: []*ast.Field{{Type: ast.NewIdent("bool")}},
			},
		},
		Body: &ast.BlockStmt{List: stmts},
	}
}

func main() {
	// Source code to transform
	src := `
package example

// Person represents a person
type Person struct {
	Name    string
	Age     int
	Address string
}

// Product represents a product
type Product struct {
	ID    string
	Price float64
}
`

	// Create a file set for position information
	fset := token.NewFileSet()

	// Parse the source code
	file, err := parser.ParseFile(fset, "example.go", src, parser.ParseComments)
	if err != nil {
		fmt.Printf("Error parsing source: %v\n", err)
		return
	}

	// Generate Equals methods
	newFile, err := GenerateEqualsMethod(file, fset)
	if err != nil {
		fmt.Printf("Error generating methods: %v\n", err)
		return
	}

	// Print the AST back to source. Report printer failures instead of
	// formatting a possibly partial buffer.
	var buf bytes.Buffer
	if err := printer.Fprint(&buf, fset, newFile); err != nil {
		fmt.Printf("Error printing code: %v\n", err)
		return
	}

	// Format the code
	formattedCode, err := format.Source(buf.Bytes())
	if err != nil {
		fmt.Printf("Error formatting code: %v\n", err)
		return
	}

	fmt.Println("Generated code:")
	// format.Source output already ends with a newline. Trim it so Println
	// does not add an extra blank line. This is also the program's use of its
	// "strings" import.
	fmt.Println(strings.TrimRight(string(formattedCode), "\n"))
}

Output:

Generated code:
package example

// Person represents a person
type Person struct {
	Name    string
	Age     int
	Address string
}

// Product represents a product
type Product struct {
	ID    string
	Price float64
}

func (s *Person) Equals(other *Person) bool {
	if other == nil {
		return false
	}
	if s.Name != other.Name {
		return false
	}
	if s.Age != other.Age {
		return false
	}
	if s.Address != other.Address {
		return false
	}
	return true
}

func (s *Product) Equals(other *Product) bool {
	if other == nil {
		return false
	}
	if s.ID != other.ID {
		return false
	}
	if s.Price != other.Price {
		return false
	}
	return true
}

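This example demonstrates how to generate equality methods for struct types by analyzing the AST and creating new method declarations. The generated != comparisons compile only for comparable field types. Fields such as slices, maps, or functions would need reflect.DeepEqual or custom comparison logic. Similar techniques can be used to generate other common methods like serialization, validation, or builder patterns.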

AST manipulation provides a powerful way to analyze and transform Go code programmatically. In the next section, we’ll explore how to build complete code generation tools that combine these techniques into reusable packages.