diff options
| author | loyalsoldier <[email protected]> | 2021-08-27 18:27:16 +0800 |
|---|---|---|
| committer | loyalsoldier <[email protected]> | 2021-08-29 20:09:57 +0800 |
| commit | 85a343aca99d864c517f13cd3169ebcc910ec0d8 (patch) | |
| tree | eccfd3680d9dc6e22f265a9525dccac85902c2ab /lib | |
| parent | 2b32e8845d9e55b6c23ebb41bd0f382100094386 (diff) | |
Refactor: use plugin architecture to support multiple I/O formats
Diffstat (limited to 'lib')
| -rw-r--r-- | lib/config.go | 128 | ||||
| -rw-r--r-- | lib/error.go | 14 | ||||
| -rw-r--r-- | lib/func.go | 59 | ||||
| -rw-r--r-- | lib/instance.go | 73 | ||||
| -rw-r--r-- | lib/lib.go | 478 |
5 files changed, 752 insertions, 0 deletions
diff --git a/lib/config.go b/lib/config.go new file mode 100644 index 00000000..20e5137a --- /dev/null +++ b/lib/config.go @@ -0,0 +1,128 @@ +package lib + +import ( + "encoding/json" + "errors" + "fmt" + "strings" +) + +var ( + inputConfigCreatorCache = make(map[string]inputConfigCreator) + outputConfigCreatorCache = make(map[string]outputConfigCreator) +) + +type inputConfigCreator func(Action, json.RawMessage) (InputConverter, error) + +type outputConfigCreator func(Action, json.RawMessage) (OutputConverter, error) + +func RegisterInputConfigCreator(id string, fn inputConfigCreator) error { + id = strings.ToLower(id) + if _, found := inputConfigCreatorCache[id]; found { + return errors.New("config creator has already been registered") + } + inputConfigCreatorCache[id] = fn + return nil +} + +func createInputConfig(id string, action Action, data json.RawMessage) (InputConverter, error) { + id = strings.ToLower(id) + fn, found := inputConfigCreatorCache[id] + if !found { + return nil, errors.New("unknown config type") + } + return fn(action, data) +} + +func RegisterOutputConfigCreator(id string, fn outputConfigCreator) error { + id = strings.ToLower(id) + if _, found := outputConfigCreatorCache[id]; found { + return errors.New("config creator has already been registered") + } + outputConfigCreatorCache[id] = fn + return nil +} + +func createOutputConfig(id string, action Action, data json.RawMessage) (OutputConverter, error) { + id = strings.ToLower(id) + fn, found := outputConfigCreatorCache[id] + if !found { + return nil, errors.New("unknown config type") + } + return fn(action, data) +} + +type config struct { + Input []*inputConvConfig `json:"input"` + Output []*outputConvConfig `json:"output"` +} + +type inputConvConfig struct { + iType string + action Action + converter InputConverter +} + +func (i *inputConvConfig) UnmarshalJSON(data []byte) error { + var temp struct { + Type string `json:"type"` + Action Action `json:"action"` + Args json.RawMessage `json:"args"` + } + + if err := json.Unmarshal(data, &temp); err != nil { + return err + } + + if !ActionsRegistry[temp.Action] { + return fmt.Errorf("invalid action %s in type %s", temp.Action, temp.Type) + } + + config, err := createInputConfig(temp.Type, temp.Action, temp.Args) + if err != nil { + return err + } + + i.iType = config.GetType() + i.action = config.GetAction() + i.converter = config + + return nil +} + +type outputConvConfig struct { + iType string + action Action + converter OutputConverter +} + +func (i *outputConvConfig) UnmarshalJSON(data []byte) error { + var temp struct { + Type string `json:"type"` + Action Action `json:"action"` + Args json.RawMessage `json:"args"` + } + + if err := json.Unmarshal(data, &temp); err != nil { + return err + } + + if temp.Action == "" { + temp.Action = ActionOutput + } + + if !ActionsRegistry[temp.Action] { + return fmt.Errorf("invalid action %s in type %s", temp.Action, temp.Type) + } + + config, err := createOutputConfig(temp.Type, temp.Action, temp.Args) + if err != nil { + return err + } + + i.iType = config.GetType() + i.action = config.GetAction() + i.converter = config + + return nil +} diff --git a/lib/error.go b/lib/error.go new file mode 100644 index 00000000..c990d367 --- /dev/null +++ b/lib/error.go @@ -0,0 +1,14 @@ +package lib + +import "errors" + +var ( + ErrDuplicatedConverter = errors.New("duplicated converter") + ErrUnknownAction = errors.New("unknown action") + ErrNotSupportedFormat = errors.New("not supported format") + ErrInvalidIPType = errors.New("invalid IP type") + ErrInvalidIP = errors.New("invalid IP address") + ErrInvalidIPLength = errors.New("invalid IP address length") + ErrInvalidIPNet = errors.New("invalid IPNet address") + ErrInvalidPrefixType = errors.New("invalid prefix type") +) diff --git a/lib/func.go b/lib/func.go new file mode 100644 index 00000000..68a3e255 --- /dev/null +++ b/lib/func.go @@ -0,0 +1,59 @@ +package lib + +import ( + "fmt" + "io" + "net/http" + "strings" +) + +var ( + inputConverterMap = make(map[string]InputConverter) + outputConverterMap = make(map[string]OutputConverter) +) + +func ListInputConverter() { + fmt.Println("All available input formats:") + for name, ic := range inputConverterMap { + fmt.Printf(" - %s (%s)\n", name, ic.GetDescription()) + } +} + +func RegisterInputConverter(name string, c InputConverter) error { + name = strings.TrimSpace(name) + if _, ok := inputConverterMap[name]; ok { + return ErrDuplicatedConverter + } + inputConverterMap[name] = c + return nil +} + +func ListOutputConverter() { + fmt.Println("All available output formats:") + for name, oc := range outputConverterMap { + fmt.Printf(" - %s (%s)\n", name, oc.GetDescription()) + } +} + +func RegisterOutputConverter(name string, c OutputConverter) error { + name = strings.TrimSpace(name) + if _, ok := outputConverterMap[name]; ok { + return ErrDuplicatedConverter + } + outputConverterMap[name] = c + return nil +} + +func getRemoteURLContent(url string) ([]byte, error) { + resp, err := http.Get(url) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("failed to get remote content -> %s: %s", url, resp.Status) + } + + return io.ReadAll(resp.Body) +} diff --git a/lib/instance.go b/lib/instance.go new file mode 100644 index 00000000..e7bcbf41 --- /dev/null +++ b/lib/instance.go @@ -0,0 +1,73 @@ +package lib + +import ( + "encoding/json" + "errors" + "os" + "strings" +) + +type Instance struct { + config *config + input []InputConverter + output []OutputConverter +} + +func NewInstance() (*Instance, error) { + return &Instance{ + config: new(config), + input: make([]InputConverter, 0), + output: make([]OutputConverter, 0), + }, nil +} + +func (i *Instance) Init(configFile string) error { + var content []byte + var err error + configFile = strings.TrimSpace(configFile) + if strings.HasPrefix(configFile, "http://") || strings.HasPrefix(configFile, "https://") { + content, err = getRemoteURLContent(configFile) + } else { + content, err = os.ReadFile(configFile) + } + if err != nil { + return err + } + + if err := json.Unmarshal(content, &i.config); err != nil { + return err + } + + for _, input := range i.config.Input { + i.input = append(i.input, input.converter) + } + + for _, output := range i.config.Output { + i.output = append(i.output, output.converter) + } + + return nil +} + +func (i *Instance) Run() error { + if len(i.input) == 0 || len(i.output) == 0 { + return errors.New("input type and output type must be specified") + } + + var err error + container := NewContainer() + for _, ic := range i.input { + container, err = ic.Input(container) + if err != nil { + return err + } + } + + for _, oc := range i.output { + if err := oc.Output(container); err != nil { + return err + } + } + + return nil +} diff --git a/lib/lib.go b/lib/lib.go new file mode 100644 index 00000000..c3203eeb --- /dev/null +++ b/lib/lib.go @@ -0,0 +1,478 @@ +package lib + +import ( + "fmt" + "log" + "net" + "strings" + "sync" + + "inet.af/netaddr" +) + +const ( + ActionAdd Action = "add" + ActionRemove Action = "remove" + ActionReplace Action = "replace" + ActionOutput Action = "output" + + IPv4 IPType = "ipv4" + IPv6 IPType = "ipv6" +) + +var ActionsRegistry = map[Action]bool{ + ActionAdd: true, + ActionRemove: true, + ActionReplace: true, + ActionOutput: true, +} + +type Action string + +type IPType string + +type Typer interface { + GetType() string +} + +type Actioner interface { + GetAction() Action +} + +type Descriptioner interface { + GetDescription() string +} + +type InputConverter interface { + Typer + Actioner + Descriptioner + Input(Container) (Container, error) +} + +type OutputConverter interface { + Typer + Actioner + Descriptioner + Output(Container) error +} + +type Entry struct { + name string + mu *sync.Mutex + ipv4Builder *netaddr.IPSetBuilder + ipv6Builder *netaddr.IPSetBuilder +} + +func NewEntry(name string) *Entry { + return &Entry{ + name: strings.ToUpper(strings.TrimSpace(name)), + mu: new(sync.Mutex), + ipv4Builder: new(netaddr.IPSetBuilder), + ipv6Builder: new(netaddr.IPSetBuilder), + } +} + +func (e *Entry) GetName() string { + return e.name +} + +func (e *Entry) hasIPv4Builder() bool { + return e.ipv4Builder != nil +} + +func (e *Entry) hasIPv6Builder() bool { + return e.ipv6Builder != nil +} + +func (e *Entry) processPrefix(src interface{}) (*netaddr.IPPrefix, IPType, error) { + switch src := src.(type) { + case net.IP: + ip, ok := netaddr.FromStdIP(src) + if !ok { + return nil, "", ErrInvalidIP + } + switch { + case ip.Is4(): + prefix := netaddr.IPPrefixFrom(ip, 32) + return &prefix, IPv4, nil + case ip.Is6(): + prefix := netaddr.IPPrefixFrom(ip, 128) + return &prefix, IPv6, nil + default: + return nil, "", ErrInvalidIPLength + } + + case *net.IPNet: + prefix, ok := netaddr.FromStdIPNet(src) + if !ok { + return nil, "", ErrInvalidIPNet + } + ip := prefix.IP() + switch { + case ip.Is4(): + return &prefix, IPv4, nil + case ip.Is6(): + return &prefix, IPv6, nil + default: + return nil, "", ErrInvalidIPLength + } + + case netaddr.IP: + switch { + case src.Is4(): + prefix := netaddr.IPPrefixFrom(src, 32) + return &prefix, IPv4, nil + case src.Is6(): + prefix := netaddr.IPPrefixFrom(src, 128) + return &prefix, IPv6, nil + default: + return nil, "", ErrInvalidIPLength + } + + case *netaddr.IP: + switch { + case src.Is4(): + prefix := netaddr.IPPrefixFrom(*src, 32) + return &prefix, IPv4, nil + case src.Is6(): + prefix := netaddr.IPPrefixFrom(*src, 128) + return &prefix, IPv6, nil + default: + return nil, "", ErrInvalidIPLength + } + + case netaddr.IPPrefix: + ip := src.IP() + switch { + case ip.Is4(): + return &src, IPv4, nil + case ip.Is6(): + return &src, IPv6, nil + default: + return nil, "", ErrInvalidIPLength + } + + case *netaddr.IPPrefix: + ip := src.IP() + switch { + case ip.Is4(): + return src, IPv4, nil + case ip.Is6(): + return src, IPv6, nil + default: + return nil, "", ErrInvalidIPLength + } + + case string: + _, network, err := net.ParseCIDR(src) + switch err { + case nil: + prefix, ok := netaddr.FromStdIPNet(network) + if !ok { + return nil, "", ErrInvalidIPNet + } + ip := prefix.IP() + switch { + case ip.Is4(): + return &prefix, IPv4, nil + case ip.Is6(): + return &prefix, IPv6, nil + default: + return nil, "", ErrInvalidIPLength + } + default: + ip, err := netaddr.ParseIP(src) + if err != nil { + return nil, "", err + } + switch { + case ip.Is4(): + prefix := netaddr.IPPrefixFrom(ip, 32) + return &prefix, IPv4, nil + case ip.Is6(): + prefix := netaddr.IPPrefixFrom(ip, 128) + return &prefix, IPv6, nil + default: + return nil, "", ErrInvalidIPLength + } + } + } + + return nil, "", ErrInvalidPrefixType +} + +func (e *Entry) add(prefix *netaddr.IPPrefix, ipType IPType) error { + e.mu.Lock() + defer e.mu.Unlock() + + switch ipType { + case IPv4: + if !e.hasIPv4Builder() { + e.ipv4Builder = new(netaddr.IPSetBuilder) + } + e.ipv4Builder.AddPrefix(*prefix) + case IPv6: + if !e.hasIPv6Builder() { + e.ipv6Builder = new(netaddr.IPSetBuilder) + } + e.ipv6Builder.AddPrefix(*prefix) + default: + return ErrInvalidIPType + } + + return nil +} + +func (e *Entry) remove(prefix *netaddr.IPPrefix, ipType IPType) error { + e.mu.Lock() + defer e.mu.Unlock() + + switch ipType { + case IPv4: + if e.hasIPv4Builder() { + e.ipv4Builder.RemovePrefix(*prefix) + } + case IPv6: + if e.hasIPv6Builder() { + e.ipv6Builder.RemovePrefix(*prefix) + } + default: + return ErrInvalidIPType + } + + return nil +} + +func (e *Entry) AddPrefix(cidr interface{}) error { + prefix, ipType, err := e.processPrefix(cidr) + if err != nil { + return err + } + if err := e.add(prefix, ipType); err != nil { + return err + } + return nil +} + +func (e *Entry) RemovePrefix(cidr string) error { + prefix, ipType, err := e.processPrefix(cidr) + if err != nil { + return err + } + if err := e.remove(prefix, ipType); err != nil { + return err + } + return nil +} + +func (e *Entry) MarshalText(opts ...IgnoreIPOption) ([]string, error) { + var ignoreIPType IPType + for _, opt := range opts { + if opt != nil { + ignoreIPType = opt() + } + } + disableIPv4, disableIPv6 := false, false + switch ignoreIPType { + case IPv4: + disableIPv4 = true + case IPv6: + disableIPv6 = true + } + + prefixSet := make([]string, 0, 1024) + + if !disableIPv4 && e.hasIPv4Builder() { + ipv4set, err := e.ipv4Builder.IPSet() + if err != nil { + return nil, err + } + prefixes := ipv4set.Prefixes() + for _, prefix := range prefixes { + prefixSet = append(prefixSet, prefix.String()) + } + } + + if !disableIPv6 && e.hasIPv6Builder() { + ipv6set, err := e.ipv6Builder.IPSet() + if err != nil { + return nil, err + } + prefixes := ipv6set.Prefixes() + for _, prefix := range prefixes { + prefixSet = append(prefixSet, prefix.String()) + } + } + + if len(prefixSet) > 0 { + return prefixSet, nil + } + + return nil, fmt.Errorf("entry %s has no prefix", e.GetName()) +} + +type IgnoreIPOption func() IPType + +func IgnoreIPv4() IPType { + return IPv4 +} + +func IgnoreIPv6() IPType { + return IPv6 +} + +type Container interface { + GetEntry(name string) (*Entry, bool) + Add(entry *Entry, opts ...IgnoreIPOption) error + Remove(name string, opts ...IgnoreIPOption) + Replace(entry *Entry, opts ...IgnoreIPOption) + Loop() <-chan *Entry +} + +type container struct { + entries *sync.Map // map[name]*Entry +} + +func NewContainer() Container { + return &container{ + entries: new(sync.Map), + } +} + +func (c *container) isValid() bool { + if c == nil || c.entries == nil { + return false + } + return true +} + +func (c *container) GetEntry(name string) (*Entry, bool) { + if !c.isValid() { + return nil, false + } + val, ok := c.entries.Load(strings.ToUpper(strings.TrimSpace(name))) + if !ok { + return nil, false + } + return val.(*Entry), true +} + +func (c *container) Loop() <-chan *Entry { + ch := make(chan *Entry, 300) + go func() { + c.entries.Range(func(key, value interface{}) bool { + ch <- value.(*Entry) + return true + }) + close(ch) + }() + return ch +} + +func (c *container) Add(entry *Entry, opts ...IgnoreIPOption) error { + var ignoreIPType IPType + for _, opt := range opts { + if opt != nil { + ignoreIPType = opt() + } + } + + name := entry.GetName() + val, found := c.GetEntry(name) + switch found { + case true: + var ipv4set, ipv6set *netaddr.IPSet + var err4, err6 error + if entry.hasIPv4Builder() { + ipv4set, err4 = entry.ipv4Builder.IPSet() + if err4 != nil { + return err4 + } + } + if entry.hasIPv6Builder() { + ipv6set, err6 = entry.ipv6Builder.IPSet() + if err6 != nil { + return err6 + } + } + switch ignoreIPType { + case IPv4: + if !val.hasIPv6Builder() { + val.ipv6Builder = new(netaddr.IPSetBuilder) + } + val.ipv6Builder.AddSet(ipv6set) + case IPv6: + if !val.hasIPv4Builder() { + val.ipv4Builder = new(netaddr.IPSetBuilder) + } + val.ipv4Builder.AddSet(ipv4set) + default: + if !val.hasIPv4Builder() { + val.ipv4Builder = new(netaddr.IPSetBuilder) + } + if !val.hasIPv6Builder() { + val.ipv6Builder = new(netaddr.IPSetBuilder) + } + val.ipv4Builder.AddSet(ipv4set) + val.ipv6Builder.AddSet(ipv6set) + } + c.entries.Store(name, val) + + case false: + switch ignoreIPType { + case IPv4: + entry.ipv4Builder = nil + case IPv6: + entry.ipv6Builder = nil + } + c.entries.Store(name, entry) + } + + return nil +} + +func (c *container) Remove(name string, opts ...IgnoreIPOption) { + val, found := c.GetEntry(name) + if !found { + log.Printf("failed to remove non-existent entry %s", name) + return + } + + var ignoreIPType IPType + for _, opt := range opts { + if opt != nil { + ignoreIPType = opt() + } + } + + switch ignoreIPType { + case IPv4: + val.ipv6Builder = nil + c.entries.Store(name, val) + case IPv6: + val.ipv4Builder = nil + c.entries.Store(name, val) + default: + c.entries.Delete(name) + } +} + +func (c *container) Replace(entry *Entry, opts ...IgnoreIPOption) { + var ignoreIPType IPType + for _, opt := range opts { + if opt != nil { + ignoreIPType = opt() + } + } + + switch ignoreIPType { + case IPv4: + entry.ipv4Builder = nil + case IPv6: + entry.ipv6Builder = nil + } + + name := entry.GetName() + c.entries.Store(name, entry) +} |
