summaryrefslogtreecommitdiff
path: root/lib
diff options
context:
space:
mode:
authorloyalsoldier <[email protected]>2021-08-27 18:27:16 +0800
committerloyalsoldier <[email protected]>2021-08-29 20:09:57 +0800
commit85a343aca99d864c517f13cd3169ebcc910ec0d8 (patch)
treeeccfd3680d9dc6e22f265a9525dccac85902c2ab /lib
parent2b32e8845d9e55b6c23ebb41bd0f382100094386 (diff)
Refactor: use plugin architecture to support multiple I/O formats
Diffstat (limited to 'lib')
-rw-r--r--lib/config.go128
-rw-r--r--lib/error.go14
-rw-r--r--lib/func.go59
-rw-r--r--lib/instance.go73
-rw-r--r--lib/lib.go478
5 files changed, 752 insertions, 0 deletions
diff --git a/lib/config.go b/lib/config.go
new file mode 100644
index 00000000..20e5137a
--- /dev/null
+++ b/lib/config.go
@@ -0,0 +1,128 @@
+package lib
+
+import (
+ "encoding/json"
+ "errors"
+ "fmt"
+ "strings"
+)
+
+var (
+ inputConfigCreatorCache = make(map[string]inputConfigCreator)
+ outputConfigCreatorCache = make(map[string]outputConfigCreator)
+)
+
+type inputConfigCreator func(Action, json.RawMessage) (InputConverter, error)
+
+type outputConfigCreator func(Action, json.RawMessage) (OutputConverter, error)
+
+func RegisterInputConfigCreator(id string, fn inputConfigCreator) error {
+ id = strings.ToLower(id)
+ if _, found := inputConfigCreatorCache[id]; found {
+ return errors.New("config creator has already been registered")
+ }
+ inputConfigCreatorCache[id] = fn
+ return nil
+}
+
+func createInputConfig(id string, action Action, data json.RawMessage) (InputConverter, error) {
+ id = strings.ToLower(id)
+ fn, found := inputConfigCreatorCache[id]
+ if !found {
+ return nil, errors.New("unknown config type")
+ }
+ return fn(action, data)
+}
+
+func RegisterOutputConfigCreator(id string, fn outputConfigCreator) error {
+ id = strings.ToLower(id)
+ if _, found := outputConfigCreatorCache[id]; found {
+ return errors.New("config creator has already been registered")
+ }
+ outputConfigCreatorCache[id] = fn
+ return nil
+}
+
+func createOutputConfig(id string, action Action, data json.RawMessage) (OutputConverter, error) {
+ id = strings.ToLower(id)
+ fn, found := outputConfigCreatorCache[id]
+ if !found {
+ return nil, errors.New("unknown config type")
+ }
+ return fn(action, data)
+}
+
+type config struct {
+ Input []*inputConvConfig `json:"input"`
+ Output []*outputConvConfig `json:"output"`
+}
+
+type inputConvConfig struct {
+ iType string
+ action Action
+ converter InputConverter
+}
+
+func (i *inputConvConfig) UnmarshalJSON(data []byte) error {
+ var temp struct {
+ Type string `json:"type"`
+ Action Action `json:"action"`
+ Args json.RawMessage `json:"args"`
+ }
+
+ if err := json.Unmarshal(data, &temp); err != nil {
+ return err
+ }
+
+ if !ActionsRegistry[temp.Action] {
+ return fmt.Errorf("invalid action %s in type %s", temp.Action, temp.Type)
+ }
+
+ config, err := createInputConfig(temp.Type, temp.Action, temp.Args)
+ if err != nil {
+ return err
+ }
+
+ i.iType = config.GetType()
+ i.action = config.GetAction()
+ i.converter = config
+
+ return nil
+}
+
+type outputConvConfig struct {
+ iType string
+ action Action
+ converter OutputConverter
+}
+
+func (i *outputConvConfig) UnmarshalJSON(data []byte) error {
+ var temp struct {
+ Type string `json:"type"`
+ Action Action `json:"action"`
+ Args json.RawMessage `json:"args"`
+ }
+
+ if err := json.Unmarshal(data, &temp); err != nil {
+ return err
+ }
+
+ if temp.Action == "" {
+ temp.Action = ActionOutput
+ }
+
+ if !ActionsRegistry[temp.Action] {
+ return fmt.Errorf("invalid action %s in type %s", temp.Action, temp.Type)
+ }
+
+ config, err := createOutputConfig(temp.Type, temp.Action, temp.Args)
+ if err != nil {
+ return err
+ }
+
+ i.iType = config.GetType()
+ i.action = config.GetAction()
+ i.converter = config
+
+ return nil
+}
diff --git a/lib/error.go b/lib/error.go
new file mode 100644
index 00000000..c990d367
--- /dev/null
+++ b/lib/error.go
@@ -0,0 +1,14 @@
+package lib
+
+import "errors"
+
+var (
+ ErrDuplicatedConverter = errors.New("duplicated converter")
+ ErrUnknownAction = errors.New("unknown action")
+ ErrNotSupportedFormat = errors.New("not supported format")
+ ErrInvalidIPType = errors.New("invalid IP type")
+ ErrInvalidIP = errors.New("invalid IP address")
+ ErrInvalidIPLength = errors.New("invalid IP address length")
+ ErrInvalidIPNet = errors.New("invalid IPNet address")
+ ErrInvalidPrefixType = errors.New("invalid prefix type")
+)
diff --git a/lib/func.go b/lib/func.go
new file mode 100644
index 00000000..68a3e255
--- /dev/null
+++ b/lib/func.go
@@ -0,0 +1,59 @@
+package lib
+
+import (
+ "fmt"
+ "io"
+ "net/http"
+ "strings"
+)
+
+var (
+ inputConverterMap = make(map[string]InputConverter)
+ outputConverterMap = make(map[string]OutputConverter)
+)
+
+func ListInputConverter() {
+ fmt.Println("All available input formats:")
+ for name, ic := range inputConverterMap {
+ fmt.Printf(" - %s (%s)\n", name, ic.GetDescription())
+ }
+}
+
+func RegisterInputConverter(name string, c InputConverter) error {
+ name = strings.TrimSpace(name)
+ if _, ok := inputConverterMap[name]; ok {
+ return ErrDuplicatedConverter
+ }
+ inputConverterMap[name] = c
+ return nil
+}
+
+func ListOutputConverter() {
+ fmt.Println("All available output formats:")
+ for name, oc := range outputConverterMap {
+ fmt.Printf(" - %s (%s)\n", name, oc.GetDescription())
+ }
+}
+
+func RegisterOutputConverter(name string, c OutputConverter) error {
+ name = strings.TrimSpace(name)
+ if _, ok := outputConverterMap[name]; ok {
+ return ErrDuplicatedConverter
+ }
+ outputConverterMap[name] = c
+ return nil
+}
+
+func getRemoteURLContent(url string) ([]byte, error) {
+ resp, err := http.Get(url)
+ if err != nil {
+ return nil, err
+ }
+ defer resp.Body.Close()
+
+ if resp.StatusCode != http.StatusOK {
+ return nil, fmt.Errorf("failed to get remote content -> %s: %s", url, resp.Status)
+ }
+
+ return io.ReadAll(resp.Body)
+}
diff --git a/lib/instance.go b/lib/instance.go
new file mode 100644
index 00000000..e7bcbf41
--- /dev/null
+++ b/lib/instance.go
@@ -0,0 +1,73 @@
+package lib
+
+import (
+ "encoding/json"
+ "errors"
+ "os"
+ "strings"
+)
+
+type Instance struct {
+ config *config
+ input []InputConverter
+ output []OutputConverter
+}
+
+func NewInstance() (*Instance, error) {
+ return &Instance{
+ config: new(config),
+ input: make([]InputConverter, 0),
+ output: make([]OutputConverter, 0),
+ }, nil
+}
+
+func (i *Instance) Init(configFile string) error {
+ var content []byte
+ var err error
+ configFile = strings.TrimSpace(configFile)
+ if strings.HasPrefix(configFile, "http://") || strings.HasPrefix(configFile, "https://") {
+ content, err = getRemoteURLContent(configFile)
+ } else {
+ content, err = os.ReadFile(configFile)
+ }
+ if err != nil {
+ return err
+ }
+
+ if err := json.Unmarshal(content, &i.config); err != nil {
+ return err
+ }
+
+ for _, input := range i.config.Input {
+ i.input = append(i.input, input.converter)
+ }
+
+ for _, output := range i.config.Output {
+ i.output = append(i.output, output.converter)
+ }
+
+ return nil
+}
+
+func (i *Instance) Run() error {
+ if len(i.input) == 0 || len(i.output) == 0 {
+ return errors.New("input type and output type must be specified")
+ }
+
+ var err error
+ container := NewContainer()
+ for _, ic := range i.input {
+ container, err = ic.Input(container)
+ if err != nil {
+ return err
+ }
+ }
+
+ for _, oc := range i.output {
+ if err := oc.Output(container); err != nil {
+ return err
+ }
+ }
+
+ return nil
+}
diff --git a/lib/lib.go b/lib/lib.go
new file mode 100644
index 00000000..c3203eeb
--- /dev/null
+++ b/lib/lib.go
@@ -0,0 +1,478 @@
+package lib
+
+import (
+ "fmt"
+ "log"
+ "net"
+ "strings"
+ "sync"
+
+ "inet.af/netaddr"
+)
+
+const (
+ ActionAdd Action = "add"
+ ActionRemove Action = "remove"
+ ActionReplace Action = "replace"
+ ActionOutput Action = "output"
+
+ IPv4 IPType = "ipv4"
+ IPv6 IPType = "ipv6"
+)
+
+var ActionsRegistry = map[Action]bool{
+ ActionAdd: true,
+ ActionRemove: true,
+ ActionReplace: true,
+ ActionOutput: true,
+}
+
+type Action string
+
+type IPType string
+
+type Typer interface {
+ GetType() string
+}
+
+type Actioner interface {
+ GetAction() Action
+}
+
+type Descriptioner interface {
+ GetDescription() string
+}
+
+type InputConverter interface {
+ Typer
+ Actioner
+ Descriptioner
+ Input(Container) (Container, error)
+}
+
+type OutputConverter interface {
+ Typer
+ Actioner
+ Descriptioner
+ Output(Container) error
+}
+
+type Entry struct {
+ name string
+ mu *sync.Mutex
+ ipv4Builder *netaddr.IPSetBuilder
+ ipv6Builder *netaddr.IPSetBuilder
+}
+
+func NewEntry(name string) *Entry {
+ return &Entry{
+ name: strings.ToUpper(strings.TrimSpace(name)),
+ mu: new(sync.Mutex),
+ ipv4Builder: new(netaddr.IPSetBuilder),
+ ipv6Builder: new(netaddr.IPSetBuilder),
+ }
+}
+
+func (e *Entry) GetName() string {
+ return e.name
+}
+
+func (e *Entry) hasIPv4Builder() bool {
+ return e.ipv4Builder != nil
+}
+
+func (e *Entry) hasIPv6Builder() bool {
+ return e.ipv6Builder != nil
+}
+
+func (e *Entry) processPrefix(src interface{}) (*netaddr.IPPrefix, IPType, error) {
+ switch src := src.(type) {
+ case net.IP:
+ ip, ok := netaddr.FromStdIP(src)
+ if !ok {
+ return nil, "", ErrInvalidIP
+ }
+ switch {
+ case ip.Is4():
+ prefix := netaddr.IPPrefixFrom(ip, 32)
+ return &prefix, IPv4, nil
+ case ip.Is6():
+ prefix := netaddr.IPPrefixFrom(ip, 128)
+ return &prefix, IPv6, nil
+ default:
+ return nil, "", ErrInvalidIPLength
+ }
+
+ case *net.IPNet:
+ prefix, ok := netaddr.FromStdIPNet(src)
+ if !ok {
+ return nil, "", ErrInvalidIPNet
+ }
+ ip := prefix.IP()
+ switch {
+ case ip.Is4():
+ return &prefix, IPv4, nil
+ case ip.Is6():
+ return &prefix, IPv6, nil
+ default:
+ return nil, "", ErrInvalidIPLength
+ }
+
+ case netaddr.IP:
+ switch {
+ case src.Is4():
+ prefix := netaddr.IPPrefixFrom(src, 32)
+ return &prefix, IPv4, nil
+ case src.Is6():
+ prefix := netaddr.IPPrefixFrom(src, 128)
+ return &prefix, IPv6, nil
+ default:
+ return nil, "", ErrInvalidIPLength
+ }
+
+ case *netaddr.IP:
+ switch {
+ case src.Is4():
+ prefix := netaddr.IPPrefixFrom(*src, 32)
+ return &prefix, IPv4, nil
+ case src.Is6():
+ prefix := netaddr.IPPrefixFrom(*src, 128)
+ return &prefix, IPv6, nil
+ default:
+ return nil, "", ErrInvalidIPLength
+ }
+
+ case netaddr.IPPrefix:
+ ip := src.IP()
+ switch {
+ case ip.Is4():
+ return &src, IPv4, nil
+ case ip.Is6():
+ return &src, IPv6, nil
+ default:
+ return nil, "", ErrInvalidIPLength
+ }
+
+ case *netaddr.IPPrefix:
+ ip := src.IP()
+ switch {
+ case ip.Is4():
+ return src, IPv4, nil
+ case ip.Is6():
+ return src, IPv6, nil
+ default:
+ return nil, "", ErrInvalidIPLength
+ }
+
+ case string:
+ _, network, err := net.ParseCIDR(src)
+ switch err {
+ case nil:
+ prefix, ok := netaddr.FromStdIPNet(network)
+ if !ok {
+ return nil, "", ErrInvalidIPNet
+ }
+ ip := prefix.IP()
+ switch {
+ case ip.Is4():
+ return &prefix, IPv4, nil
+ case ip.Is6():
+ return &prefix, IPv6, nil
+ default:
+ return nil, "", ErrInvalidIPLength
+ }
+ default:
+ ip, err := netaddr.ParseIP(src)
+ if err != nil {
+ return nil, "", err
+ }
+ switch {
+ case ip.Is4():
+ prefix := netaddr.IPPrefixFrom(ip, 32)
+ return &prefix, IPv4, nil
+ case ip.Is6():
+ prefix := netaddr.IPPrefixFrom(ip, 128)
+ return &prefix, IPv6, nil
+ default:
+ return nil, "", ErrInvalidIPLength
+ }
+ }
+ }
+
+ return nil, "", ErrInvalidPrefixType
+}
+
+func (e *Entry) add(prefix *netaddr.IPPrefix, ipType IPType) error {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+
+ switch ipType {
+ case IPv4:
+ if !e.hasIPv4Builder() {
+ e.ipv4Builder = new(netaddr.IPSetBuilder)
+ }
+ e.ipv4Builder.AddPrefix(*prefix)
+ case IPv6:
+ if !e.hasIPv6Builder() {
+ e.ipv6Builder = new(netaddr.IPSetBuilder)
+ }
+ e.ipv6Builder.AddPrefix(*prefix)
+ default:
+ return ErrInvalidIPType
+ }
+
+ return nil
+}
+
+func (e *Entry) remove(prefix *netaddr.IPPrefix, ipType IPType) error {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+
+ switch ipType {
+ case IPv4:
+ if e.hasIPv4Builder() {
+ e.ipv4Builder.RemovePrefix(*prefix)
+ }
+ case IPv6:
+ if e.hasIPv6Builder() {
+ e.ipv6Builder.RemovePrefix(*prefix)
+ }
+ default:
+ return ErrInvalidIPType
+ }
+
+ return nil
+}
+
+func (e *Entry) AddPrefix(cidr interface{}) error {
+ prefix, ipType, err := e.processPrefix(cidr)
+ if err != nil {
+ return err
+ }
+ if err := e.add(prefix, ipType); err != nil {
+ return err
+ }
+ return nil
+}
+
+func (e *Entry) RemovePrefix(cidr string) error {
+ prefix, ipType, err := e.processPrefix(cidr)
+ if err != nil {
+ return err
+ }
+ if err := e.remove(prefix, ipType); err != nil {
+ return err
+ }
+ return nil
+}
+
+func (e *Entry) MarshalText(opts ...IgnoreIPOption) ([]string, error) {
+ var ignoreIPType IPType
+ for _, opt := range opts {
+ if opt != nil {
+ ignoreIPType = opt()
+ }
+ }
+ disableIPv4, disableIPv6 := false, false
+ switch ignoreIPType {
+ case IPv4:
+ disableIPv4 = true
+ case IPv6:
+ disableIPv6 = true
+ }
+
+ prefixSet := make([]string, 0, 1024)
+
+ if !disableIPv4 && e.hasIPv4Builder() {
+ ipv4set, err := e.ipv4Builder.IPSet()
+ if err != nil {
+ return nil, err
+ }
+ prefixes := ipv4set.Prefixes()
+ for _, prefix := range prefixes {
+ prefixSet = append(prefixSet, prefix.String())
+ }
+ }
+
+ if !disableIPv6 && e.hasIPv6Builder() {
+ ipv6set, err := e.ipv6Builder.IPSet()
+ if err != nil {
+ return nil, err
+ }
+ prefixes := ipv6set.Prefixes()
+ for _, prefix := range prefixes {
+ prefixSet = append(prefixSet, prefix.String())
+ }
+ }
+
+ if len(prefixSet) > 0 {
+ return prefixSet, nil
+ }
+
+ return nil, fmt.Errorf("entry %s has no prefix", e.GetName())
+}
+
+type IgnoreIPOption func() IPType
+
+func IgnoreIPv4() IPType {
+ return IPv4
+}
+
+func IgnoreIPv6() IPType {
+ return IPv6
+}
+
+type Container interface {
+ GetEntry(name string) (*Entry, bool)
+ Add(entry *Entry, opts ...IgnoreIPOption) error
+ Remove(name string, opts ...IgnoreIPOption)
+ Replace(entry *Entry, opts ...IgnoreIPOption)
+ Loop() <-chan *Entry
+}
+
+type container struct {
+ entries *sync.Map // map[name]*Entry
+}
+
+func NewContainer() Container {
+ return &container{
+ entries: new(sync.Map),
+ }
+}
+
+func (c *container) isValid() bool {
+ if c == nil || c.entries == nil {
+ return false
+ }
+ return true
+}
+
+func (c *container) GetEntry(name string) (*Entry, bool) {
+ if !c.isValid() {
+ return nil, false
+ }
+ val, ok := c.entries.Load(strings.ToUpper(strings.TrimSpace(name)))
+ if !ok {
+ return nil, false
+ }
+ return val.(*Entry), true
+}
+
+func (c *container) Loop() <-chan *Entry {
+ ch := make(chan *Entry, 300)
+ go func() {
+ c.entries.Range(func(key, value interface{}) bool {
+ ch <- value.(*Entry)
+ return true
+ })
+ close(ch)
+ }()
+ return ch
+}
+
+func (c *container) Add(entry *Entry, opts ...IgnoreIPOption) error {
+ var ignoreIPType IPType
+ for _, opt := range opts {
+ if opt != nil {
+ ignoreIPType = opt()
+ }
+ }
+
+ name := entry.GetName()
+ val, found := c.GetEntry(name)
+ switch found {
+ case true:
+ var ipv4set, ipv6set *netaddr.IPSet
+ var err4, err6 error
+ if entry.hasIPv4Builder() {
+ ipv4set, err4 = entry.ipv4Builder.IPSet()
+ if err4 != nil {
+ return err4
+ }
+ }
+ if entry.hasIPv6Builder() {
+ ipv6set, err6 = entry.ipv6Builder.IPSet()
+ if err6 != nil {
+ return err6
+ }
+ }
+ switch ignoreIPType {
+ case IPv4:
+ if !val.hasIPv6Builder() {
+ val.ipv6Builder = new(netaddr.IPSetBuilder)
+ }
+ val.ipv6Builder.AddSet(ipv6set)
+ case IPv6:
+ if !val.hasIPv4Builder() {
+ val.ipv4Builder = new(netaddr.IPSetBuilder)
+ }
+ val.ipv4Builder.AddSet(ipv4set)
+ default:
+ if !val.hasIPv4Builder() {
+ val.ipv4Builder = new(netaddr.IPSetBuilder)
+ }
+ if !val.hasIPv6Builder() {
+ val.ipv6Builder = new(netaddr.IPSetBuilder)
+ }
+ val.ipv4Builder.AddSet(ipv4set)
+ val.ipv6Builder.AddSet(ipv6set)
+ }
+ c.entries.Store(name, val)
+
+ case false:
+ switch ignoreIPType {
+ case IPv4:
+ entry.ipv4Builder = nil
+ case IPv6:
+ entry.ipv6Builder = nil
+ }
+ c.entries.Store(name, entry)
+ }
+
+ return nil
+}
+
+func (c *container) Remove(name string, opts ...IgnoreIPOption) {
+ val, found := c.GetEntry(name)
+ if !found {
+ log.Printf("failed to remove non-existent entry %s", name)
+ return
+ }
+
+ var ignoreIPType IPType
+ for _, opt := range opts {
+ if opt != nil {
+ ignoreIPType = opt()
+ }
+ }
+
+ switch ignoreIPType {
+ case IPv4:
+ val.ipv6Builder = nil
+ c.entries.Store(name, val)
+ case IPv6:
+ val.ipv4Builder = nil
+ c.entries.Store(name, val)
+ default:
+ c.entries.Delete(name)
+ }
+}
+
+func (c *container) Replace(entry *Entry, opts ...IgnoreIPOption) {
+ var ignoreIPType IPType
+ for _, opt := range opts {
+ if opt != nil {
+ ignoreIPType = opt()
+ }
+ }
+
+ switch ignoreIPType {
+ case IPv4:
+ entry.ipv4Builder = nil
+ case IPv6:
+ entry.ipv6Builder = nil
+ }
+
+ name := entry.GetName()
+ c.entries.Store(name, entry)
+}