
Discriminated Union Pattern in Go

Published: at 09:23 PM


In TypeScript, a relatively common pattern is the discriminated union (also called a tagged union). It is useful, for example, when dealing with events. It is the formal name for an input object that carries a type field distinguishing its shape from the other possible shapes. Every shape has a unique value for the type field, which lets us use a switch statement to determine the type of the object and act accordingly.

type ClickEvent = {
  type: 'click',
  x: number,
  y: number
}
type KeyPressEvent = {
  type: 'key_press',
  key: string
}
type ScrollEvent = {
  type: 'scroll',
  amount: number,
}
type Event = ClickEvent | KeyPressEvent | ScrollEvent;

function handleEvent(eventJson: string) {
  // This is a leap of faith that the API did not lie to us
  const event = JSON.parse(eventJson) as Event;

  switch (event.type) {
    case 'click': // We know that event is of type ClickEvent
      console.log(`Clicked at ${event.x}, ${event.y}`);
      break;
    case 'key_press': // We know that event is of type KeyPressEvent
      console.log(`Pressed key ${event.key}`);
      break;
    case 'scroll': // We know that event is of type ScrollEvent
      console.log(`Scrolled ${event.amount}`);
      break;
  }
}

Go has no union types, so there are multiple ways to model the same behavior.

Parsing with both a base struct and a specific struct

A simple way to model this in Go is to parse the JSON twice: first into a struct that contains only the type field, and then, based on the value of that field, into the correct specific struct.


package main

import (
	"encoding/json"
	"fmt"
	"log"
)

type BaseEvent struct {
	Type string `json:"type"`
}

type ClickEvent struct {
	BaseEvent
	X int `json:"x"`
	Y int `json:"y"`
}

type KeyPressEvent struct {
	BaseEvent
	Key string `json:"key"`
}

type ScrollEvent struct {
	BaseEvent
	Amount int `json:"amount"`
}

func HandleEvent(eventJson string) {
	var baseEvent BaseEvent

	// Parse the JSON into the base event.
	// (No leap of faith here: we dispatch on the type field we actually read,
	// and a missing or unknown type ends up in the default branch.)
	if err := json.Unmarshal([]byte(eventJson), &baseEvent); err != nil {
		log.Fatalf("Failed to unmarshal base event: %v", err)
	}

	switch baseEvent.Type {
	case "click":
		var clickEvent ClickEvent
		if err := json.Unmarshal([]byte(eventJson), &clickEvent); err != nil {
			log.Fatalf("Failed to unmarshal click event: %v", err)
		}
		fmt.Printf("Clicked at %d, %d\n", clickEvent.X, clickEvent.Y)
	case "key_press":
		var keyPressEvent KeyPressEvent
		if err := json.Unmarshal([]byte(eventJson), &keyPressEvent); err != nil {
			log.Fatalf("Failed to unmarshal key press event: %v", err)
		}
		fmt.Printf("Pressed key %s\n", keyPressEvent.Key)
	case "scroll":
		var scrollEvent ScrollEvent
		if err := json.Unmarshal([]byte(eventJson), &scrollEvent); err != nil {
			log.Fatalf("Failed to unmarshal scroll event: %v", err)
		}
		fmt.Printf("Scrolled %d\n", scrollEvent.Amount)
	default:
		fmt.Printf("Unknown event type: %s\n", baseEvent.Type)
	}
}
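The handler above prints or calls `log.Fatalf` on failure, which makes it awkward to test. Below is a minimal sketch of the same two-pass idea that returns a description and an error instead; `describeEvent` is a hypothetical name, and the struct definitions are repeated so the block runs on its own. Note that a mismatched field type, such as a string `x` on a click event, is only caught by the second `json.Unmarshal`, because the first pass reads nothing but the type field.

```go
package main

import (
	"encoding/json"
	"fmt"
)

// BaseEvent holds only the discriminator field.
type BaseEvent struct {
	Type string `json:"type"`
}

type ClickEvent struct {
	BaseEvent
	X int `json:"x"`
	Y int `json:"y"`
}

type KeyPressEvent struct {
	BaseEvent
	Key string `json:"key"`
}

type ScrollEvent struct {
	BaseEvent
	Amount int `json:"amount"`
}

// describeEvent is a hypothetical variant of HandleEvent that returns a
// description and an error instead of printing and exiting.
func describeEvent(eventJSON string) (string, error) {
	data := []byte(eventJSON)

	// First pass: read only the discriminator.
	var base BaseEvent
	if err := json.Unmarshal(data, &base); err != nil {
		return "", fmt.Errorf("unmarshal base event: %w", err)
	}

	// Second pass: decode into the struct selected by the type field.
	switch base.Type {
	case "click":
		var e ClickEvent
		if err := json.Unmarshal(data, &e); err != nil {
			return "", fmt.Errorf("unmarshal click event: %w", err)
		}
		return fmt.Sprintf("Clicked at %d, %d", e.X, e.Y), nil
	case "key_press":
		var e KeyPressEvent
		if err := json.Unmarshal(data, &e); err != nil {
			return "", fmt.Errorf("unmarshal key press event: %w", err)
		}
		return fmt.Sprintf("Pressed key %s", e.Key), nil
	case "scroll":
		var e ScrollEvent
		if err := json.Unmarshal(data, &e); err != nil {
			return "", fmt.Errorf("unmarshal scroll event: %w", err)
		}
		return fmt.Sprintf("Scrolled %d", e.Amount), nil
	default:
		return "", fmt.Errorf("unknown event type: %q", base.Type)
	}
}

func main() {
	inputs := []string{
		`{"type":"click","x":10,"y":20}`,
		`{"type":"key_press","key":"a"}`,
		`{"type":"scroll","amount":3}`,
		`{"type":"resize"}`,
	}
	for _, in := range inputs {
		desc, err := describeEvent(in)
		if err != nil {
			fmt.Println("error:", err)
			continue
		}
		fmt.Println(desc)
	}
}
```

Running it prints one line per input; the `resize` input falls into the `default` branch and is reported as an error instead of terminating the program.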

This approach is relatively simple and easy to understand. However, it has the downside of parsing the JSON twice. This might be a problem if the JSON is large or parsing is expensive.

Using a catch-all struct that converts to a specific event type

Another way to model this in Go is to use a single struct that contains the fields of all the specific structs. A method on this struct then returns the correct specific type based on the value of the type field. This way we only need to parse the JSON once, at the cost of more boilerplate for copying the necessary data over. It also requires that event types sharing a JSON field name agree on that field's type, because the catch-all struct has only one field per name.


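package main

import (
	"encoding/json"
	"fmt"
	"log"
)

// Event is implemented by every specific event type.
type Event interface {
	String() string
}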
type CatchAllEvent struct {
	Type   string `json:"type"`
	X      int    `json:"x"`
	Y      int    `json:"y"`
	Key    string `json:"key"`
	Amount int    `json:"amount"`
}

func (e CatchAllEvent) toSpecificEvent() Event {
	switch e.Type {
	case "click":
		return ClickEvent{X: e.X, Y: e.Y}
	case "key_press":
		return KeyPressEvent{Key: e.Key}
	case "scroll":
		return ScrollEvent{Amount: e.Amount}
	default:
		// Unknown type: callers must check for nil before using the result.
		return nil
	}
}

type ClickEvent struct {
	X int `json:"x"`
	Y int `json:"y"`
}

var _ Event = ClickEvent{}

func (e ClickEvent) String() string {
	return fmt.Sprintf("Clicked at %d, %d", e.X, e.Y)
}

type KeyPressEvent struct {
	Key string `json:"key"`
}

var _ Event = KeyPressEvent{}

func (e KeyPressEvent) String() string {
	return fmt.Sprintf("Pressed key %s", e.Key)
}

type ScrollEvent struct {
	Amount int `json:"amount"`
}

var _ Event = ScrollEvent{}

func (e ScrollEvent) String() string {
	return fmt.Sprintf("Scrolled %d", e.Amount)
}

func HandleEvent(eventJson string) {
	// Parse JSON with CatchAllEvent
	var catchAllEvent CatchAllEvent
	if err := json.Unmarshal([]byte(eventJson), &catchAllEvent); err != nil {
		log.Fatalf("Failed to unmarshal event: %v", err)
	}

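	specificEvent := catchAllEvent.toSpecificEvent()
	if specificEvent == nil {
		// Calling String() on a nil interface would panic.
		fmt.Printf("Unknown event type: %s\n", catchAllEvent.Type)
		return
	}
	fmt.Println(specificEvent.String())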
}
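Because `toSpecificEvent` returns an `Event` interface, callers that need more than `String()` can recover the concrete type with a type switch, which is the closest Go counterpart to the TypeScript `switch (event.type)` narrowing. A minimal sketch, assuming a hypothetical `parseEvent` helper that turns the `nil` result for an unknown type into an error, and a hypothetical `handle` function; the event types are trimmed copies of the ones above.

```go
package main

import (
	"encoding/json"
	"fmt"
)

// Event mirrors the interface from the catch-all approach.
type Event interface {
	String() string
}

type CatchAllEvent struct {
	Type   string `json:"type"`
	X      int    `json:"x"`
	Y      int    `json:"y"`
	Key    string `json:"key"`
	Amount int    `json:"amount"`
}

type ClickEvent struct{ X, Y int }
type KeyPressEvent struct{ Key string }
type ScrollEvent struct{ Amount int }

func (e ClickEvent) String() string    { return fmt.Sprintf("Clicked at %d, %d", e.X, e.Y) }
func (e KeyPressEvent) String() string { return fmt.Sprintf("Pressed key %s", e.Key) }
func (e ScrollEvent) String() string   { return fmt.Sprintf("Scrolled %d", e.Amount) }

func (e CatchAllEvent) toSpecificEvent() Event {
	switch e.Type {
	case "click":
		return ClickEvent{X: e.X, Y: e.Y}
	case "key_press":
		return KeyPressEvent{Key: e.Key}
	case "scroll":
		return ScrollEvent{Amount: e.Amount}
	default:
		return nil
	}
}

// parseEvent (hypothetical helper): one Unmarshal, then conversion,
// reporting unknown types as an error instead of returning nil.
func parseEvent(eventJSON string) (Event, error) {
	var c CatchAllEvent
	if err := json.Unmarshal([]byte(eventJSON), &c); err != nil {
		return nil, err
	}
	ev := c.toSpecificEvent()
	if ev == nil {
		return nil, fmt.Errorf("unknown event type: %q", c.Type)
	}
	return ev, nil
}

// handle (hypothetical) uses a type switch to recover the concrete type.
func handle(ev Event) string {
	switch e := ev.(type) {
	case ClickEvent:
		return fmt.Sprintf("click at x=%d", e.X)
	case KeyPressEvent:
		return "key " + e.Key
	case ScrollEvent:
		return fmt.Sprintf("scroll by %d", e.Amount)
	default:
		return "unhandled"
	}
}

func main() {
	inputs := []string{`{"type":"click","x":1,"y":2}`, `{"type":"scroll","amount":5}`, `{"type":"resize"}`}
	for _, in := range inputs {
		ev, err := parseEvent(in)
		if err != nil {
			fmt.Println("error:", err)
			continue
		}
		fmt.Println(handle(ev))
	}
}
```

The `default` branch of the type switch only runs for `Event` implementations that are not listed, so adding a new event type means updating `CatchAllEvent`, `toSpecificEvent`, and every type switch that should handle it.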

If the JSON is large or parsing is expensive, the second approach is more efficient because it unmarshals the input only once. However, the first approach is easier to understand and maintain, so I would recommend it unless performance is a concern.