diff options
Diffstat (limited to 'lib')
| -rw-r--r-- | lib/common_test.go | 139 | ||||
| -rw-r--r-- | lib/config_test.go | 130 | ||||
| -rw-r--r-- | lib/container_test.go | 445 | ||||
| -rw-r--r-- | lib/converter_test.go | 44 | ||||
| -rw-r--r-- | lib/entry_test.go | 414 | ||||
| -rw-r--r-- | lib/helpers_test.go | 100 | ||||
| -rw-r--r-- | lib/instance_test.go | 143 | ||||
| -rw-r--r-- | lib/lib_test.go | 12 |
8 files changed, 1427 insertions, 0 deletions
diff --git a/lib/common_test.go b/lib/common_test.go new file mode 100644 index 00000000..71d55cc5 --- /dev/null +++ b/lib/common_test.go @@ -0,0 +1,139 @@ +package lib + +import ( + "io" + "net/http" + "net/http/httptest" + "testing" +) + +func TestGetRemoteURLContent(t *testing.T) { + t.Run("success", func(t *testing.T) { + s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("hello")) + })) + defer s.Close() + + data, err := GetRemoteURLContent(s.URL) + if err != nil { + t.Fatalf("GetRemoteURLContent() error = %v", err) + } + if string(data) != "hello" { + t.Fatalf("GetRemoteURLContent() = %s, want %s", data, "hello") + } + }) + + t.Run("status error", func(t *testing.T) { + s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusBadGateway) + })) + defer s.Close() + + if _, err := GetRemoteURLContent(s.URL); err == nil { + t.Fatalf("expected error for non-200 response") + } + }) + + t.Run("request error", func(t *testing.T) { + if _, err := GetRemoteURLContent("http://[%"); err == nil { + t.Fatalf("expected error for invalid URL") + } + }) +} + +func TestGetRemoteURLReader(t *testing.T) { + t.Run("success", func(t *testing.T) { + s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("world")) + })) + defer s.Close() + + rc, err := GetRemoteURLReader(s.URL) + if err != nil { + t.Fatalf("GetRemoteURLReader() error = %v", err) + } + defer rc.Close() + + data, err := io.ReadAll(rc) + if err != nil { + t.Fatalf("unexpected read error: %v", err) + } + if string(data) != "world" { + t.Fatalf("GetRemoteURLReader() = %s, want %s", data, "world") + } + }) + + t.Run("status error", func(t *testing.T) { + s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusBadRequest) + })) + defer s.Close() + + if rc, err := GetRemoteURLReader(s.URL); err == nil { + rc.Close() + t.Fatalf("expected error for non-200 response") + } + }) + + t.Run("request error", func(t *testing.T) { + if _, err := GetRemoteURLReader("http://[%"); err == nil { + t.Fatalf("expected error for invalid URL") + } + }) +} + +func TestGetIgnoreIPType(t *testing.T) { + if fn := GetIgnoreIPType(IPv4); fn == nil || fn() != IPv6 { + t.Fatalf("GetIgnoreIPType(IPv4) = %v", fn) + } + if fn := GetIgnoreIPType(IPv6); fn == nil || fn() != IPv4 { + t.Fatalf("GetIgnoreIPType(IPv6) = %v", fn) + } + if fn := GetIgnoreIPType(IPType("other")); fn != nil { + t.Fatalf("GetIgnoreIPType(other) = %v, want nil", fn) + } +} + +func TestWantedListExtendedUnmarshalJSON(t *testing.T) { + t.Run("slice input", func(t *testing.T) { + var w WantedListExtended + if err := w.UnmarshalJSON([]byte(`["a","b"]`)); err != nil { + t.Fatalf("UnmarshalJSON() error = %v", err) + } + if len(w.TypeSlice) != 2 || w.TypeSlice[0] != "a" || w.TypeSlice[1] != "b" { + t.Fatalf("TypeSlice = %#v", w.TypeSlice) + } + if len(w.TypeMap) != 0 { + t.Fatalf("TypeMap should be empty, got %#v", w.TypeMap) + } + }) + + t.Run("map input", func(t *testing.T) { + var w WantedListExtended + if err := w.UnmarshalJSON([]byte(`{"x":["y"]}`)); err != nil { + t.Fatalf("UnmarshalJSON() error = %v", err) + } + if len(w.TypeSlice) != 0 { + t.Fatalf("TypeSlice should be empty, got %#v", w.TypeSlice) + } + if got := w.TypeMap["x"]; len(got) != 1 || got[0] != "y" { + t.Fatalf("TypeMap = %#v", w.TypeMap) + } + }) + + t.Run("invalid input", func(t *testing.T) { + var w WantedListExtended + if err := w.UnmarshalJSON([]byte(`123`)); err == nil { + t.Fatalf("expected error for invalid json") + } + }) + + t.Run("empty input", func(t *testing.T) { + var w WantedListExtended + if err := w.UnmarshalJSON([]byte(``)); err != nil { + t.Fatalf("UnmarshalJSON() error = %v", err) + } + }) +} diff --git a/lib/config_test.go b/lib/config_test.go new file mode 100644 index 00000000..41c62764 --- /dev/null +++ b/lib/config_test.go @@ -0,0 +1,130 @@ +package lib + +import ( + "encoding/json" + "testing" +) + +func TestRegisterInputConfigCreator(t *testing.T) { + resetConfigCreators() + + if err := RegisterInputConfigCreator("sample", func(a Action, data json.RawMessage) (InputConverter, error) { + return mockInputConverter{typ: "sample", action: a}, nil + }); err != nil { + t.Fatalf("RegisterInputConfigCreator() error = %v", err) + } + + if err := RegisterInputConfigCreator("sample", nil); err == nil { + t.Fatalf("expected error for duplicated creator") + } +} + +func TestCreateInputConfig(t *testing.T) { + resetConfigCreators() + + if _, err := createInputConfig("unknown", ActionAdd, nil); err == nil { + t.Fatalf("expected error for unknown config type") + } + + _ = RegisterInputConfigCreator("known", func(a Action, data json.RawMessage) (InputConverter, error) { + return mockInputConverter{typ: "known", action: a}, nil + }) + + cfg, err := createInputConfig("known", ActionRemove, nil) + if err != nil { + t.Fatalf("createInputConfig() error = %v", err) + } + if cfg.GetAction() != ActionRemove || cfg.GetType() != "known" { + t.Fatalf("unexpected converter: %v %v", cfg.GetType(), cfg.GetAction()) + } +} + +func TestRegisterOutputConfigCreator(t *testing.T) { + resetConfigCreators() + + if err := RegisterOutputConfigCreator("sample", func(a Action, data json.RawMessage) (OutputConverter, error) { + return mockOutputConverter{typ: "sample", action: a}, nil + }); err != nil { + t.Fatalf("RegisterOutputConfigCreator() error = %v", err) + } + + if err := RegisterOutputConfigCreator("sample", nil); err == nil { + t.Fatalf("expected error for duplicated creator") + } +} + +func TestCreateOutputConfig(t *testing.T) { + resetConfigCreators() + + if _, err := createOutputConfig("unknown", ActionAdd, nil); err == nil { + t.Fatalf("expected error for unknown config type") + } + + _ = RegisterOutputConfigCreator("known", func(a Action, data json.RawMessage) (OutputConverter, error) { + return mockOutputConverter{typ: "known", action: a}, nil + }) + + cfg, err := createOutputConfig("known", ActionOutput, nil) + if err != nil { + t.Fatalf("createOutputConfig() error = %v", err) + } + if cfg.GetAction() != ActionOutput || cfg.GetType() != "known" { + t.Fatalf("unexpected converter: %v %v", cfg.GetType(), cfg.GetAction()) + } +} + +func TestInputConvConfigUnmarshalJSON(t *testing.T) { + resetConfigCreators() + _ = RegisterInputConfigCreator("stub", func(a Action, data json.RawMessage) (InputConverter, error) { + return mockInputConverter{typ: "stub", action: a}, nil + }) + + jsonData := []byte(`{"type":"stub","action":"add","args":{}}`) + var cfg inputConvConfig + if err := cfg.UnmarshalJSON(jsonData); err != nil { + t.Fatalf("UnmarshalJSON() error = %v", err) + } + if cfg.iType != "stub" || cfg.action != ActionAdd { + t.Fatalf("unexpected values: %v %v", cfg.iType, cfg.action) + } + + if err := cfg.UnmarshalJSON([]byte(`{"type":"stub","action":"invalid"}`)); err == nil { + t.Fatalf("expected error for invalid action") + } + + if err := cfg.UnmarshalJSON([]byte(`{"type":"unknown","action":"add"}`)); err == nil { + t.Fatalf("expected error for unknown type") + } + + if err := cfg.UnmarshalJSON([]byte(`{`)); err == nil { + t.Fatalf("expected json error") + } +} + +func TestOutputConvConfigUnmarshalJSON(t *testing.T) { + resetConfigCreators() + _ = RegisterOutputConfigCreator("stub", func(a Action, data json.RawMessage) (OutputConverter, error) { + return mockOutputConverter{typ: "stub", action: a}, nil + }) + + jsonData := []byte(`{"type":"stub","args":{}}`) + var cfg outputConvConfig + if err := cfg.UnmarshalJSON(jsonData); err != nil { + t.Fatalf("UnmarshalJSON() error = %v", err) + } + if cfg.iType != "stub" || cfg.action != ActionOutput { + t.Fatalf("unexpected values: %v %v", cfg.iType, cfg.action) + } + + if err := cfg.UnmarshalJSON([]byte(`{"type":"stub","action":"invalid"}`)); err == nil { + t.Fatalf("expected error for invalid action") + } + + if err := cfg.UnmarshalJSON([]byte(`{"type":"unknown","action":"add"}`)); err == nil { + t.Fatalf("expected error for unknown type") + } + + if err := cfg.UnmarshalJSON([]byte(`{`)); err == nil { + t.Fatalf("expected json error") + } +} diff --git a/lib/container_test.go b/lib/container_test.go new file mode 100644 index 00000000..6b4827d7 --- /dev/null +++ b/lib/container_test.go @@ -0,0 +1,445 @@ +package lib + +import ( + "net/netip" + "testing" + + "go4.org/netipx" +) + +func TestNewContainerBasicOperations(t *testing.T) { + c := NewContainer().(*container) + if !c.isValid() { + t.Fatalf("new container should be valid") + } + if c.Len() != 0 { + t.Fatalf("expected len 0, got %d", c.Len()) + } + if entry, ok := c.GetEntry("missing"); ok || entry != nil { + t.Fatalf("expected missing entry") + } + + invalid := &container{} + if entry, ok := invalid.GetEntry("anything"); ok || entry != nil { + t.Fatalf("expected invalid container to return nil entry") + } + if invalid.Len() != 0 { + t.Fatalf("invalid container length should be 0") + } +} + +func TestContainerAddAndMerge(t *testing.T) { + c := NewContainer() + + entry := NewEntry("test") + if err := entry.AddPrefix("10.0.0.0/24"); err != nil { + t.Fatalf("AddPrefix() error = %v", err) + } + if err := entry.AddPrefix("2001:db8::/32"); err != nil { + t.Fatalf("AddPrefix() error = %v", err) + } + + if err := c.Add(entry); err != nil { + t.Fatalf("Add() error = %v", err) + } + + got, ok := c.GetEntry("TEST") + if !ok || got == nil { + t.Fatalf("entry not found after add") + } + + // merge with existing entry, should append new prefixes + entry2 := NewEntry("test") + if err := entry2.AddPrefix("192.0.2.0/24"); err != nil { + t.Fatalf("AddPrefix() error = %v", err) + } + if err := c.Add(entry2); err != nil { + t.Fatalf("Add() error = %v", err) + } + + ipset, err := got.GetIPv4Set() + if err != nil { + t.Fatalf("GetIPv4Set() error = %v", err) + } + if !ipset.Contains(netip.MustParseAddr("10.0.0.1")) || !ipset.Contains(netip.MustParseAddr("192.0.2.1")) { + t.Fatalf("merged IPv4 set missing prefixes") + } +} + +func TestContainerAddWithIgnore(t *testing.T) { + c := NewContainer() + + entry := NewEntry("mix") + if err := entry.AddPrefix("10.1.0.0/16"); err != nil { + t.Fatalf("AddPrefix() error = %v", err) + } + if err := entry.AddPrefix("2001:db8:1::/48"); err != nil { + t.Fatalf("AddPrefix() error = %v", err) + } + + if err := c.Add(entry, IgnoreIPv6); err != nil { + t.Fatalf("Add() error = %v", err) + } + + got, _ := c.GetEntry("mix") + if got.hasIPv6Builder() { + t.Fatalf("expected IPv6 builder to be nil when ignored") + } +} + +func TestContainerRemovePrefixAndEntry(t *testing.T) { + c := NewContainer() + entry := NewEntry("remove") + _ = entry.AddPrefix("10.2.0.0/16") + _ = entry.AddPrefix("2001:db8:2::/48") + _ = c.Add(entry) + + // remove prefix + removeEntry := NewEntry("remove") + _ = removeEntry.AddPrefix("10.2.0.0/24") + if err := c.Remove(removeEntry, CaseRemovePrefix); err != nil { + t.Fatalf("Remove() error = %v", err) + } + + got, _ := c.GetEntry("remove") + ipset, _ := got.GetIPv4Set() + if ipset.Contains(netip.MustParseAddr("10.2.0.1")) { + t.Fatalf("expected prefix to be removed") + } + + // remove only IPv4 builder + removeEntry2 := NewEntry("remove") + _ = c.Remove(removeEntry2, CaseRemoveEntry, IgnoreIPv6) + if got.hasIPv4Builder() { + t.Fatalf("expected IPv4 builder cleared") + } + + // remove the entry entirely + if err := c.Remove(removeEntry2, CaseRemoveEntry); err != nil { + t.Fatalf("Remove() error = %v", err) + } + if c.Len() != 0 { + t.Fatalf("expected container empty after removal") + } +} + +func TestContainerRemoveBranches(t *testing.T) { + c := NewContainer() + entry := NewEntry("rb") + _ = entry.AddPrefix("10.6.0.0/16") + _ = entry.AddPrefix("2001:db8:6::/48") + _ = c.Add(entry) + + r1 := NewEntry("rb") + _ = r1.AddPrefix("2001:db8:6::/48") + if err := c.Remove(r1, CaseRemovePrefix, IgnoreIPv4); err != nil { + t.Fatalf("Remove() error = %v", err) + } + entry.ipv6Set = nil + if set, _ := entry.GetIPv6Set(); set.Contains(netip.MustParseAddr("2001:db8:6::1")) { + t.Fatalf("expected ipv6 prefix removed") + } + + r2 := NewEntry("rb") + _ = r2.AddPrefix("10.6.0.0/16") + if err := c.Remove(r2, CaseRemovePrefix, IgnoreIPv6); err != nil { + t.Fatalf("Remove() error = %v", err) + } + entry.ipv4Set = nil + if set, _ := entry.GetIPv4Set(); set.Contains(netip.MustParseAddr("10.6.0.1")) { + t.Fatalf("expected ipv4 prefix removed") + } + + // Add a new IPv4 prefix and clear only IPv6 builder + _ = entry.AddPrefix("10.6.1.0/24") + if err := c.Remove(NewEntry("rb"), CaseRemoveEntry, IgnoreIPv4); err != nil { + t.Fatalf("Remove() error = %v", err) + } + if entry.hasIPv6Builder() { + t.Fatalf("expected ipv6 builder to be cleared") + } + + // error from invalid builder + bad := NewEntry("rb") + bad.ipv4Builder = &netipx.IPSetBuilder{} + bad.ipv4Builder.AddPrefix(netip.Prefix{}) + if err := c.Remove(bad, CaseRemovePrefix); err == nil { + t.Fatalf("expected error from invalid builder") + } + + badv6 := NewEntry("rb") + badv6.ipv6Builder = &netipx.IPSetBuilder{} + badv6.ipv6Builder.AddPrefix(netip.Prefix{}) + if err := c.Remove(badv6, CaseRemovePrefix); err == nil { + t.Fatalf("expected error from invalid ipv6 builder") + } + + // create missing builders during remove + only6 := NewEntry("only6") + _ = only6.AddPrefix("2001:db8:10::/48") + _ = c.Add(only6) + remove4 := NewEntry("only6") + _ = remove4.AddPrefix("203.0.113.0/24") + if err := c.Remove(remove4, CaseRemovePrefix, IgnoreIPv6); err != nil { + t.Fatalf("Remove() error = %v", err) + } + got, _ := c.GetEntry("only6") + if !got.hasIPv4Builder() { + t.Fatalf("expected ipv4 builder created during remove") + } + + only4 := NewEntry("only4") + _ = only4.AddPrefix("198.51.101.0/24") + _ = c.Add(only4) + remove6 := NewEntry("only4") + _ = remove6.AddPrefix("2001:db8:11::/48") + if err := c.Remove(remove6, CaseRemovePrefix, IgnoreIPv4); err != nil { + t.Fatalf("Remove() error = %v", err) + } + got2, _ := c.GetEntry("only4") + if !got2.hasIPv6Builder() { + t.Fatalf("expected ipv6 builder created during remove") + } + + empty := &Entry{name: "EMPTY"} + c.(*container).entries["EMPTY"] = empty + removeEmpty := NewEntry("empty") + _ = removeEmpty.AddPrefix("10.0.0.0/24") + if err := c.Remove(removeEmpty, CaseRemovePrefix); err != nil { + t.Fatalf("Remove() error = %v", err) + } + if got3, _ := c.GetEntry("empty"); !got3.hasIPv4Builder() || !got3.hasIPv6Builder() { + t.Fatalf("expected builders created in default remove branch") + } + + // unknown remove case with existing entry + if err := c.Remove(NewEntry("only4"), CaseRemove(123)); err == nil { + t.Fatalf("expected error for unknown remove case on existing entry") + } +} +func TestContainerRemoveErrors(t *testing.T) { + c := NewContainer() + entry := NewEntry("missing") + + if err := c.Remove(entry, CaseRemoveEntry); err == nil { + t.Fatalf("expected error when removing missing entry") + } + + if err := c.Remove(entry, CaseRemove(99)); err == nil { + t.Fatalf("expected error for unknown remove case") + } +} + +func TestContainerAddErrorAndIgnoreBranches(t *testing.T) { + c := NewContainer() + valid := NewEntry("mix") + _ = valid.AddPrefix("10.5.0.0/16") + _ = valid.AddPrefix("2001:db8:5::/48") + _ = c.Add(valid) + + // ignore IPv4 when adding a new entry + entry := NewEntry("newone") + _ = entry.AddPrefix("203.0.113.0/24") + _ = entry.AddPrefix("2001:db8:6::/48") + if err := c.Add(entry, IgnoreIPv4); err != nil { + t.Fatalf("Add() error = %v", err) + } + if got, _ := c.GetEntry("newone"); got.hasIPv4Builder() { + t.Fatalf("expected IPv4 builder nil when ignored") + } + + // found=true path with ignore IPv4 branch + moreIPv6 := NewEntry("mix") + _ = moreIPv6.AddPrefix("2001:db8:7::/48") + if err := c.Add(moreIPv6, IgnoreIPv4); err != nil { + t.Fatalf("Add() error = %v", err) + } + existing, _ := c.GetEntry("mix") + ipv6set, _ := existing.GetIPv6Set() + if !ipv6set.Contains(netip.MustParseAddr("2001:db8:7::1")) { + t.Fatalf("expected IPv6 prefix merged") + } + + // found=true path with ignore IPv6 branch + moreIPv4 := NewEntry("mix") + _ = moreIPv4.AddPrefix("10.5.1.0/24") + if err := c.Add(moreIPv4, IgnoreIPv6); err != nil { + t.Fatalf("Add() error = %v", err) + } + existing.ipv4Set = nil + ipv4set, _ := existing.GetIPv4Set() + if !ipv4set.Contains(netip.MustParseAddr("10.5.1.1")) { + t.Fatalf("expected IPv4 prefix merged") + } + + // error from building ipset in found=true path + bad := NewEntry("mix") + bad.ipv4Builder = &netipx.IPSetBuilder{} + bad.ipv4Builder.AddPrefix(netip.Prefix{}) // invalid, accumulates error + if err := c.Add(bad); err == nil { + t.Fatalf("expected error from invalid builder") + } + + bad6 := NewEntry("mix") + bad6.ipv6Builder = &netipx.IPSetBuilder{} + bad6.ipv6Builder.AddPrefix(netip.Prefix{}) + if err := c.Add(bad6); err == nil { + t.Fatalf("expected error from invalid ipv6 builder") + } +} + +func TestContainerAddCreatesMissingBuilders(t *testing.T) { + c := NewContainer() + + partial := NewEntry("partial") + _ = partial.AddPrefix("10.7.0.0/16") + _ = c.Add(partial) + + addIPv6 := NewEntry("partial") + _ = addIPv6.AddPrefix("2001:db8:7::/48") + if err := c.Add(addIPv6, IgnoreIPv4); err != nil { + t.Fatalf("Add() error = %v", err) + } + if got, _ := c.GetEntry("partial"); !got.hasIPv6Builder() { + t.Fatalf("expected ipv6 builder created") + } + + partial2 := NewEntry("partial2") + _ = partial2.AddPrefix("2001:db8:8::/48") + _ = c.Add(partial2) + + addIPv4 := NewEntry("partial2") + _ = addIPv4.AddPrefix("198.18.0.0/16") + if err := c.Add(addIPv4, IgnoreIPv6); err != nil { + t.Fatalf("Add() error = %v", err) + } + if got, _ := c.GetEntry("partial2"); !got.hasIPv4Builder() { + t.Fatalf("expected ipv4 builder created") + } + + partial3 := &Entry{name: "PARTIAL3"} + c.(*container).entries["PARTIAL3"] = partial3 + addBoth := NewEntry("partial3") + _ = addBoth.AddPrefix("198.19.0.0/16") + _ = addBoth.AddPrefix("2001:db8:9::/48") + if err := c.Add(addBoth); err != nil { + t.Fatalf("Add() error = %v", err) + } + if got, _ := c.GetEntry("partial3"); !got.hasIPv4Builder() || !got.hasIPv6Builder() { + t.Fatalf("expected both builders created") + } +} +func TestContainerLookup(t *testing.T) { + c := NewContainer() + entry := NewEntry("zoneA") + _ = entry.AddPrefix("10.3.0.0/16") + _ = entry.AddPrefix("2001:db8:a::/48") + _ = c.Add(entry) + + entry2 := NewEntry("zoneB") + _ = entry2.AddPrefix("2001:db8:3::/48") + _ = entry2.AddPrefix("198.51.100.0/24") + _ = c.Add(entry2) + + t.Run("match ipv4", func(t *testing.T) { + names, ok, err := c.Lookup("10.3.5.1") + if err != nil { + t.Fatalf("Lookup() error = %v", err) + } + if !ok || len(names) != 1 || names[0] != "ZONEA" { + t.Fatalf("Lookup() got %v %v", names, ok) + } + }) + + t.Run("match ipv6 prefix", func(t *testing.T) { + names, ok, err := c.Lookup("2001:db8:3::/48") + if err != nil { + t.Fatalf("Lookup() error = %v", err) + } + if !ok || len(names) != 1 || names[0] != "ZONEB" { + t.Fatalf("Lookup() got %v %v", names, ok) + } + }) + + t.Run("match ipv4 prefix", func(t *testing.T) { + names, ok, err := c.Lookup("10.3.0.0/16") + if err != nil { + t.Fatalf("Lookup() error = %v", err) + } + if !ok || len(names) != 1 || names[0] != "ZONEA" { + t.Fatalf("Lookup() got %v %v", names, ok) + } + }) + + t.Run("search list filters", func(t *testing.T) { + names, ok, err := c.Lookup("10.3.5.1", "zoneB") + if err != nil { + t.Fatalf("Lookup() error = %v", err) + } + if ok || len(names) != 0 { + t.Fatalf("expected no results when filtered out, got %v", names) + } + }) + + t.Run("invalid input", func(t *testing.T) { + if _, _, err := c.Lookup("not-an-ip"); err == nil { + t.Fatalf("expected error for invalid input") + } + }) + + t.Run("ipv6 address lookup", func(t *testing.T) { + names, ok, err := c.Lookup("2001:db8:a::1") + if err != nil || !ok || len(names) != 1 { + t.Fatalf("Lookup() ipv6 = %v %v %v", names, ok, err) + } + }) + + t.Run("invalid prefix string", func(t *testing.T) { + if _, _, err := c.Lookup("bad/64"); err == nil { + t.Fatalf("expected error for invalid prefix") + } + }) + + t.Run("entry lookup error", func(t *testing.T) { + c2 := NewContainer() + e := NewEntry("only6") + _ = e.AddPrefix("2001:db8::/32") + _ = c2.Add(e) + if _, _, err := c2.Lookup("192.0.2.1"); err == nil { + t.Fatalf("expected error when IPv4 set missing") + } + }) +} + +func TestContainerLoopChannel(t *testing.T) { + c := NewContainer() + entry := NewEntry("loop") + _ = entry.AddPrefix("10.4.0.0/16") + _ = c.Add(entry) + + count := 0 + for range c.(*container).Loop() { + count++ + } + if count != 1 { + t.Fatalf("expected to loop over 1 entry, got %d", count) + } +} + +func TestContainerInternalLookup(t *testing.T) { + c := &container{ + entries: map[string]*Entry{}, + } + e := NewEntry("inner") + _ = e.AddPrefix("203.0.113.0/24") + _ = c.Add(e) + + prefix := netip.MustParsePrefix("203.0.113.0/24") + names, ok, err := c.lookup(prefix, IPv4) + if err != nil { + t.Fatalf("lookup() error = %v", err) + } + if !ok || len(names) != 1 || names[0] != "INNER" { + t.Fatalf("lookup() got %v %v", names, ok) + } +} diff --git a/lib/converter_test.go b/lib/converter_test.go new file mode 100644 index 00000000..9c92a569 --- /dev/null +++ b/lib/converter_test.go @@ -0,0 +1,44 @@ +package lib + +import ( + "strings" + "testing" +) + +func TestRegisterInputConverter(t *testing.T) { + resetInputConverters() + if err := RegisterInputConverter("json", mockInputConverter{typ: "json", action: ActionAdd}); err != nil { + t.Fatalf("RegisterInputConverter() error = %v", err) + } + if err := RegisterInputConverter("json", mockInputConverter{}); err != ErrDuplicatedConverter { + t.Fatalf("expected ErrDuplicatedConverter, got %v", err) + } +} + +func TestRegisterOutputConverter(t *testing.T) { + resetOutputConverters() + if err := RegisterOutputConverter("txt", mockOutputConverter{typ: "txt", action: ActionOutput}); err != nil { + t.Fatalf("RegisterOutputConverter() error = %v", err) + } + if err := RegisterOutputConverter("txt", mockOutputConverter{}); err != ErrDuplicatedConverter { + t.Fatalf("expected ErrDuplicatedConverter, got %v", err) + } +} + +func TestListConverters(t *testing.T) { + resetInputConverters() + resetOutputConverters() + + _ = RegisterInputConverter("b", mockInputConverter{typ: "b", desc: "second"}) + _ = RegisterInputConverter("a", mockInputConverter{typ: "a", desc: "first"}) + _ = RegisterOutputConverter("x", mockOutputConverter{typ: "x", desc: "out"}) + + out := captureOutput(t, func() { + ListInputConverter() + ListOutputConverter() + }) + + if !strings.Contains(out, "a") || !strings.Contains(out, "b") || !strings.Contains(out, "x") { + t.Fatalf("unexpected output: %s", out) + } +} diff --git a/lib/entry_test.go b/lib/entry_test.go new file mode 100644 index 00000000..f3a11989 --- /dev/null +++ b/lib/entry_test.go @@ -0,0 +1,414 @@ +package lib + +import ( + "net" + "net/netip" + "testing" + + "go4.org/netipx" +) + +func TestProcessPrefixVariants(t *testing.T) { + e := NewEntry("proc") + + ipv4 := net.ParseIP("1.1.1.1") + p, ipType, err := e.processPrefix(ipv4) + if err != nil || ipType != IPv4 || p.String() != "1.1.1.1/32" { + t.Fatalf("processPrefix(net.IPv4) = %v %v %v", p, ipType, err) + } + + if _, _, err := e.processPrefix(net.IP{}); err != ErrInvalidIP { + t.Fatalf("expected ErrInvalidIP for empty net.IP, got %v", err) + } + + ipv6 := net.ParseIP("2001:db8::1") + p, ipType, err = e.processPrefix(ipv6) + if err != nil || ipType != IPv6 || p.String() != "2001:db8::1/128" { + t.Fatalf("processPrefix(net.IPv6) = %v %v %v", p, ipType, err) + } + + _, n, _ := net.ParseCIDR("10.0.0.0/24") + p, ipType, err = e.processPrefix(n) + if err != nil || ipType != IPv4 || p.String() != "10.0.0.0/24" { + t.Fatalf("processPrefix(*net.IPNet) = %v %v %v", p, ipType, err) + } + + _, n6, _ := net.ParseCIDR("2001:db8:ffff::/48") + p, ipType, err = e.processPrefix(n6) + if err != nil || ipType != IPv6 || p.String() != "2001:db8:ffff::/48" { + t.Fatalf("processPrefix(*net.IPNet ipv6) = %v %v %v", p, ipType, err) + } + + badNet := &net.IPNet{IP: net.IPv4(1, 2, 3, 4), Mask: net.IPMask{1}} + if _, _, err := e.processPrefix(badNet); err != ErrInvalidIPNet { + t.Fatalf("expected ErrInvalidIPNet, got %v", err) + } + + addr := netip.MustParseAddr("192.0.2.1") + p, ipType, err = e.processPrefix(addr) + if err != nil || ipType != IPv4 || p.String() != "192.0.2.1/32" { + t.Fatalf("processPrefix(netip.Addr) = %v %v %v", p, ipType, err) + } + + ipv6Addr := netip.MustParseAddr("2001:db8::3") + p, ipType, err = e.processPrefix(ipv6Addr) + if err != nil || ipType != IPv6 || p.String() != "2001:db8::3/128" { + t.Fatalf("processPrefix(netip.Addr ipv6) = %v %v %v", p, ipType, err) + } + + addrPtr := netip.MustParseAddr("2001:db8::2") + p, ipType, err = e.processPrefix(&addrPtr) + if err != nil || ipType != IPv6 || p.String() != "2001:db8::2/128" { + t.Fatalf("processPrefix(*netip.Addr) = %v %v %v", p, ipType, err) + } + + addrPtr4 := netip.MustParseAddr("198.18.0.1") + p, ipType, err = e.processPrefix(&addrPtr4) + if err != nil || ipType != IPv4 || p.String() != "198.18.0.1/32" { + t.Fatalf("processPrefix(*netip.Addr ipv4) = %v %v %v", p, ipType, err) + } + + prefix := netip.MustParsePrefix("198.51.100.0/24") + p, ipType, err = e.processPrefix(prefix) + if err != nil || ipType != IPv4 || p.String() != "198.51.100.0/24" { + t.Fatalf("processPrefix(netip.Prefix) = %v %v %v", p, ipType, err) + } + + ipv6PrefixVal := netip.MustParsePrefix("2001:db8:abcd::/48") + if p, ipType, err := e.processPrefix(ipv6PrefixVal); err != nil || ipType != IPv6 || p.String() != "2001:db8:abcd::/48" { + t.Fatalf("processPrefix(netip.Prefix ipv6) = %v %v %v", p, ipType, err) + } + + prefixPtr := netip.MustParsePrefix("2001:db8:ffff::/48") + p, ipType, err = e.processPrefix(&prefixPtr) + if err != nil || ipType != IPv6 || p.String() != "2001:db8:ffff::/48" { + t.Fatalf("processPrefix(*netip.Prefix) = %v %v %v", p, ipType, err) + } + + prefixPtr4 := netip.MustParsePrefix("198.51.100.0/24") + p, ipType, err = e.processPrefix(&prefixPtr4) + if err != nil || ipType != IPv4 || p.String() != "198.51.100.0/24" { + t.Fatalf("processPrefix(*netip.Prefix ipv4) = %v %v %v", p, ipType, err) + } + + // IPv4-mapped IPv6 with insufficient bits should be rejected + badPrefix := netip.MustParsePrefix("::ffff:192.0.2.1/95") + if _, _, err := e.processPrefix(badPrefix); err != ErrInvalidPrefix { + t.Fatalf("expected ErrInvalidPrefix, got %v", err) + } + + mappedPrefix := netip.MustParsePrefix("::ffff:192.0.2.0/120") + if p, ipType, err := e.processPrefix(mappedPrefix); err != nil || ipType != IPv4 || p.String() != "192.0.2.0/24" { + t.Fatalf("processPrefix(mappedPrefix) = %v %v %v", p, ipType, err) + } + + invalidPrefix4 := netip.PrefixFrom(netip.MustParseAddr("1.1.1.1"), 40) + if _, _, err := e.processPrefix(invalidPrefix4); err != ErrInvalidPrefix { + t.Fatalf("expected ErrInvalidPrefix for invalid ipv4 prefix, got %v", err) + } + + invalidPrefix6 := netip.PrefixFrom(netip.MustParseAddr("2001:db8::1"), 200) + if _, _, err := e.processPrefix(invalidPrefix6); err != ErrInvalidPrefix { + t.Fatalf("expected ErrInvalidPrefix for invalid ipv6 prefix, got %v", err) + } + + invalidPrefix4Ptr := invalidPrefix4 + if _, _, err := e.processPrefix(&invalidPrefix4Ptr); err != ErrInvalidPrefix { + t.Fatalf("expected ErrInvalidPrefix for invalid ipv4 prefix pointer, got %v", err) + } + + invalidPrefix6Ptr := invalidPrefix6 + if _, _, err := e.processPrefix(&invalidPrefix6Ptr); err != ErrInvalidPrefix { + t.Fatalf("expected ErrInvalidPrefix for invalid ipv6 prefix pointer, got %v", err) + } + + zeroPrefix := netip.Prefix{} + if _, _, err := e.processPrefix(zeroPrefix); err != ErrInvalidIPLength { + t.Fatalf("expected ErrInvalidIPLength for zero prefix, got %v", err) + } + + badPrefixPtr := badPrefix + if _, _, err := e.processPrefix(&badPrefixPtr); err != ErrInvalidPrefix { + t.Fatalf("expected ErrInvalidPrefix for pointer bad prefix, got %v", err) + } + + zeroPrefixPtr := zeroPrefix + if _, _, err := e.processPrefix(&zeroPrefixPtr); err != ErrInvalidIPLength { + t.Fatalf("expected ErrInvalidIPLength for zero prefix pointer, got %v", err) + } + + addrZero := netip.Addr{} + if _, _, err := e.processPrefix(&addrZero); err != ErrInvalidIPLength { + t.Fatalf("expected ErrInvalidIPLength for zero addr pointer, got %v", err) + } + + mappedPrefixPtr := mappedPrefix + if p, ipType, err := e.processPrefix(&mappedPrefixPtr); err != nil || ipType != IPv4 || p.String() != "192.0.2.0/24" { + t.Fatalf("processPrefix(mappedPrefixPtr) = %v %v %v", p, ipType, err) + } + + if _, _, err := e.processPrefix(netip.Addr{}); err != ErrInvalidIPLength { + t.Fatalf("expected ErrInvalidIPLength, got %v", err) + } + + if _, _, err := e.processPrefix("1.2.3.4"); err != nil { + t.Fatalf("processPrefix(string ip) error = %v", err) + } + + if _, _, err := e.processPrefix("2001:db8::1"); err != nil { + t.Fatalf("processPrefix(string ipv6) error = %v", err) + } + + if _, _, err := e.processPrefix("10.0.0.0/8"); err != nil { + t.Fatalf("processPrefix(string cidr) error = %v", err) + } + + if _, _, err := e.processPrefix("2001:db8::/32"); err != nil { + t.Fatalf("processPrefix(string cidr ipv6) error = %v", err) + } + + if _, _, err := e.processPrefix("invalid/24"); err != ErrInvalidCIDR { + t.Fatalf("expected ErrInvalidCIDR, got %v", err) + } + + if _, _, err := e.processPrefix(" //comment"); err != ErrCommentLine { + t.Fatalf("expected ErrCommentLine, got %v", err) + } + + if _, _, err := e.processPrefix(123); err != ErrInvalidPrefixType { + t.Fatalf("expected ErrInvalidPrefixType, got %v", err) + } +} + +func TestEntryAddAndRemovePrefix(t *testing.T) { + e := NewEntry("demo") + + if err := e.AddPrefix("10.0.0.0/24"); err != nil { + t.Fatalf("AddPrefix() error = %v", err) + } + if err := e.AddPrefix("2001:db8::/32"); err != nil { + t.Fatalf("AddPrefix() error = %v", err) + } + + ipv4set, err := e.GetIPv4Set() + if err != nil || !ipv4set.Contains(netip.MustParseAddr("10.0.0.1")) { + t.Fatalf("IPv4 set missing data: %v %v", ipv4set, err) + } + + ipv6set, err := e.GetIPv6Set() + if err != nil || !ipv6set.Contains(netip.MustParseAddr("2001:db8::1")) { + t.Fatalf("IPv6 set missing data: %v %v", ipv6set, err) + } + + if err := e.RemovePrefix("10.0.0.0/24"); err != nil { + t.Fatalf("RemovePrefix() error = %v", err) + } + e.ipv4Set = nil + ipv4set, _ = e.GetIPv4Set() + if ipv4set.Contains(netip.MustParseAddr("10.0.0.1")) { + t.Fatalf("prefix should be removed") + } + + if err := e.RemovePrefix("2001:db8::/32"); err != nil { + t.Fatalf("RemovePrefix() error = %v", err) + } + e.ipv6Set = nil + ipv6set, _ = e.GetIPv6Set() + if ipv6set.Contains(netip.MustParseAddr("2001:db8::1")) { + t.Fatalf("ipv6 prefix should be removed") + } + + if err := e.RemovePrefix("invalid"); err == nil { + t.Fatalf("expected error for invalid prefix") + } +} + +func TestEntryAddRemoveInvalidIPType(t *testing.T) { + e := NewEntry("invalid") + if err := e.add(nil, IPType("unknown")); err != ErrInvalidIPType { + t.Fatalf("expected ErrInvalidIPType, got %v", err) + } + if err := e.remove(nil, IPType("unknown")); err != ErrInvalidIPType { + t.Fatalf("expected ErrInvalidIPType, got %v", err) + } +} + +func TestEntryMarshalFunctions(t *testing.T) { + e := NewEntry("marshal") + _ = e.AddPrefix("203.0.113.0/24") + _ = e.AddPrefix("2001:db8:abcd::/48") + + prefixes, err := e.MarshalPrefix() + if err != nil || len(prefixes) != 2 { + t.Fatalf("MarshalPrefix() = %v, %v", prefixes, err) + } + + prefixes, err = e.MarshalPrefix(IgnoreIPv6) + if err != nil || len(prefixes) != 1 { + t.Fatalf("MarshalPrefix(IgnoreIPv6) = %v, %v", prefixes, err) + } + + ranges, err := e.MarshalIPRange() + if err != nil || len(ranges) != 2 { + t.Fatalf("MarshalIPRange() = %v, %v", ranges, err) + } + + ranges, err = e.MarshalIPRange(IgnoreIPv4) + if err != nil || len(ranges) != 1 { + t.Fatalf("MarshalIPRange(IgnoreIPv4) = %v, %v", ranges, err) + } + + text, err := e.MarshalText() + if err != nil || len(text) != 2 { + t.Fatalf("MarshalText() = %v, %v", text, err) + } + + // Ignore IPv4 results + prefixes, err = e.MarshalPrefix(IgnoreIPv4) + if err != nil || len(prefixes) != 1 { + t.Fatalf("MarshalPrefix(IgnoreIPv4) = %v, %v", prefixes, err) + } + + text, err = e.MarshalText(IgnoreIPv4) + if err != nil || len(text) != 1 { + t.Fatalf("MarshalText(IgnoreIPv4) = %v, %v", text, err) + } + + text, err = e.MarshalText(IgnoreIPv6) + if err != nil || len(text) != 1 { + t.Fatalf("MarshalText(IgnoreIPv6) = %v, %v", text, err) + } +} + +func TestEntryMarshalErrors(t *testing.T) { + e := NewEntry("empty") + + if _, err := e.MarshalPrefix(); err == nil { + t.Fatalf("expected error for empty entry") + } + if _, err := e.MarshalIPRange(); err == nil { + t.Fatalf("expected error for empty entry") + } + if _, err := e.MarshalText(); err == nil { + t.Fatalf("expected error for empty entry") + } +} + +func TestEntryBuildIPSetError(t *testing.T) { + e := NewEntry("errbuild") + builder := &netipx.IPSetBuilder{} + builder.AddPrefix(netip.Prefix{}) // invalid prefix triggers builder error + e.ipv4Builder = builder + + if _, err := e.GetIPv4Set(); err == nil { + t.Fatalf("expected buildIPSet error") + } + + builder6 := &netipx.IPSetBuilder{} + builder6.AddPrefix(netip.Prefix{}) + e.ipv6Builder = builder6 + if _, err := e.GetIPv6Set(); err == nil { + t.Fatalf("expected buildIPSet error for ipv6") + } + + if _, err := e.MarshalPrefix(); err == nil { + t.Fatalf("expected MarshalPrefix error from builder") + } + + if _, err := e.MarshalIPRange(); err == nil { + t.Fatalf("expected MarshalIPRange error from builder") + } + + if _, err := e.MarshalText(); err == nil { + t.Fatalf("expected MarshalText error from builder") + } + + // Use fresh entries to ensure builder errors are preserved + e2 := NewEntry("errbuild2") + b2 := &netipx.IPSetBuilder{} + b2.AddPrefix(netip.Prefix{}) + e2.ipv4Builder = b2 + if _, err := e2.MarshalPrefix(); err == nil { + t.Fatalf("expected MarshalPrefix error from builder") + } + + e3 := NewEntry("errbuild3") + b3 := &netipx.IPSetBuilder{} + b3.AddPrefix(netip.Prefix{}) + e3.ipv4Builder = b3 + if _, err := e3.MarshalIPRange(); err == nil { + t.Fatalf("expected MarshalIPRange error from builder") + } + + e4 := NewEntry("errbuild4") + b4 := &netipx.IPSetBuilder{} + b4.AddPrefix(netip.Prefix{}) + e4.ipv4Builder = b4 + if _, err := e4.MarshalText(); err == nil { + t.Fatalf("expected MarshalText error from builder") + } +} + +func TestEntryGetSetErrors(t *testing.T) { + e := NewEntry("sets") + _ = e.AddPrefix("2001:db8::/32") + + if _, err := e.GetIPv4Set(); err == nil { + t.Fatalf("expected error for missing IPv4 set") + } + + e2 := NewEntry("sets2") + _ = e2.AddPrefix("192.0.2.0/24") + if _, err := e2.GetIPv6Set(); err == nil { + t.Fatalf("expected error for missing IPv6 set") + } +} + +func TestAddPrefixErrorPath(t *testing.T) { + e := NewEntry("err") + if err := e.AddPrefix("bad-prefix"); err == nil { + t.Fatalf("expected error for bad prefix") + } + if err := e.RemovePrefix("bad-prefix"); err == nil { + t.Fatalf("expected error for bad prefix removal") + } +} + +func TestEntryCommentLineHandling(t *testing.T) { + e := NewEntry("comment") + if err := e.AddPrefix("# this is comment"); err != ErrInvalidIPType { + t.Fatalf("expected ErrInvalidIPType for comment line, got %v", err) + } + if err := e.RemovePrefix("// another comment"); err != ErrInvalidIPType { + t.Fatalf("expected ErrInvalidIPType for comment line removal, got %v", err) + } +} + +func TestEntryMarshalIgnoreIPv6(t *testing.T) { + e := NewEntry("ignore6") + _ = e.AddPrefix("2001:db8::/32") + _ = e.AddPrefix("198.51.100.0/24") + + ranges, err := e.MarshalIPRange(IgnoreIPv6) + if err != nil || len(ranges) != 1 { + t.Fatalf("MarshalIPRange(IgnoreIPv6) = %v, %v", ranges, err) + } +} + +func TestEntryBuildIPSetReuse(t *testing.T) { + e := NewEntry("reuse") + builder := netipx.IPSetBuilder{} + builder.AddPrefix(netip.MustParsePrefix("10.10.0.0/16")) + e.ipv4Builder = &builder + + if err := e.buildIPSet(); err != nil { + t.Fatalf("buildIPSet() error = %v", err) + } + if !e.hasIPv4Set() { + t.Fatalf("expected ipv4 set to be built") + } + // Second call should be a no-op but still succeed + if err := e.buildIPSet(); err != nil { + t.Fatalf("buildIPSet() second call error = %v", err) + } +} diff --git a/lib/helpers_test.go b/lib/helpers_test.go new file mode 100644 index 00000000..2ea80692 --- /dev/null +++ b/lib/helpers_test.go @@ -0,0 +1,100 @@ +package lib + +import ( + "bytes" + "io" + "os" + "testing" +) + +type mockInputConverter struct { + typ string + action Action + desc string + err error + inputFn func(Container) (Container, error) +} + +func (m mockInputConverter) GetType() string { + return m.typ +} + +func (m mockInputConverter) GetAction() Action { + return m.action +} + +func (m mockInputConverter) GetDescription() string { + if m.desc != "" { + return m.desc + } + return "mock input converter" +} + +func (m mockInputConverter) Input(c Container) (Container, error) { + if m.inputFn != nil { + return m.inputFn(c) + } + return c, m.err +} + +type mockOutputConverter struct { + typ string + action Action + desc string + err error + outFn func(Container) error +} + +func (m mockOutputConverter) GetType() string { + return m.typ +} + +func (m mockOutputConverter) GetAction() Action { + return m.action +} + +func (m mockOutputConverter) GetDescription() string { + if m.desc != "" { + return m.desc + } + return "mock output converter" +} + +func (m mockOutputConverter) Output(c Container) error { + if m.outFn != nil { + return m.outFn(c) + } + return m.err +} + +func captureOutput(t *testing.T, fn func()) string { + t.Helper() + + stdout := os.Stdout + r, w, _ := os.Pipe() + os.Stdout = w + + fn() + + _ = w.Close() + os.Stdout = stdout + + var buf bytes.Buffer + _, _ = io.Copy(&buf, r) + _ = r.Close() + + return buf.String() +} + +func resetInputConverters() { + inputConverterMap = make(map[string]InputConverter) +} + +func resetOutputConverters() { + outputConverterMap = make(map[string]OutputConverter) +} + +func resetConfigCreators() { + inputConfigCreatorCache = make(map[string]inputConfigCreator) + outputConfigCreatorCache = make(map[string]outputConfigCreator) +} diff --git a/lib/instance_test.go b/lib/instance_test.go new file mode 100644 index 00000000..496037d4 --- /dev/null +++ b/lib/instance_test.go @@ -0,0 +1,143 @@ +package lib + +import ( + "encoding/json" + "errors" + "net/http" + "net/http/httptest" + "os" + "testing" +) + +func setupConfigCreators() { + resetConfigCreators() + _ = RegisterInputConfigCreator("stubinput", func(a Action, data json.RawMessage) (InputConverter, error) { + return mockInputConverter{typ: "stubinput", action: a}, nil + }) + _ = RegisterOutputConfigCreator("stuboutput", func(a Action, data json.RawMessage) (OutputConverter, error) { + return mockOutputConverter{typ: "stuboutput", action: a}, nil + }) +} + +func TestInitConfigFromBytes(t *testing.T) { + setupConfigCreators() + inst, _ := NewInstance() + + content := []byte(` + { + // comment + "input": [ + {"type": "stubinput", "action": "add", "args": {}}, + ], + "output": [ + {"type": "stuboutput", "args": {}}, + ], + } + `) + + if err := inst.InitConfigFromBytes(content); err != nil { + t.Fatalf("InitConfigFromBytes() error = %v", err) + } + if len(inst.(*instance).input) != 1 || len(inst.(*instance).output) != 1 { + t.Fatalf("expected converters to be loaded") + } + + if err := inst.InitConfigFromBytes([]byte(`{`)); err == nil { + t.Fatalf("expected JSON error") + } +} + +func TestInitConfigLocalAndRemote(t *testing.T) { + setupConfigCreators() + inst, _ := NewInstance() + + data := `{"input":[{"type":"stubinput","action":"add","args":{}}],"output":[{"type":"stuboutput","args":{}}]}` + + tmp, err := os.CreateTemp("", "config*.json") + if err != nil { + t.Fatalf("CreateTemp error: %v", err) + } + defer os.Remove(tmp.Name()) + if _, err := tmp.WriteString(data); err != nil { + t.Fatalf("write temp file error: %v", err) + } + tmp.Close() + + if err := inst.InitConfig(tmp.Name()); err != nil { + t.Fatalf("InitConfig(local) error = %v", err) + } + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(data)) + })) + defer server.Close() + + if err := inst.InitConfig(server.URL); err != nil { + t.Fatalf("InitConfig(remote) error = %v", err) + } + + if err := inst.InitConfig("non-existent.json"); err == nil { + t.Fatalf("expected error for missing file") + } + + errorServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + })) + defer errorServer.Close() + if err := inst.InitConfig(errorServer.URL); err == nil { + t.Fatalf("expected error for remote failure") + } +} + +func TestInstanceRun(t *testing.T) { + inst, _ := NewInstance() + + if err := inst.Run(); err == nil { + t.Fatalf("expected error when no input/output configured") + } + + inputErr := errors.New("input fail") + inst.AddInput(mockInputConverter{inputFn: func(c Container) (Container, error) { + return c, inputErr + }}) + inst.AddOutput(mockOutputConverter{}) + if err := inst.Run(); !errors.Is(err, inputErr) { + t.Fatalf("expected input error, got %v", err) + } + + inst.ResetInput() + inst.ResetOutput() + + outputErr := errors.New("output fail") + inst.AddInput(mockInputConverter{inputFn: func(c Container) (Container, error) { + return c, nil + }}) + inst.AddOutput(mockOutputConverter{outFn: func(c Container) error { + return outputErr + }}) + if err := inst.Run(); !errors.Is(err, outputErr) { + t.Fatalf("expected output error, got %v", err) + } + + inst.ResetInput() + inst.ResetOutput() + + inputCalled := false + outputCalled := false + inst.AddInput(mockInputConverter{inputFn: func(c Container) (Container, error) { + inputCalled = true + return c, nil + }}) + inst.AddOutput(mockOutputConverter{outFn: func(c Container) error { + outputCalled = true + return nil + }}) + + if err := inst.Run(); err != nil { + t.Fatalf("Run() error = %v", err) + } + if !inputCalled || !outputCalled { + t.Fatalf("expected both input and output to be called") + } +} diff --git a/lib/lib_test.go b/lib/lib_test.go new file mode 100644 index 00000000..f851994c --- /dev/null +++ b/lib/lib_test.go @@ -0,0 +1,12 @@ +package lib + +import "testing" + +func TestIgnoreIPHelpers(t *testing.T) { + if got := IgnoreIPv4(); got != IPv4 { + t.Fatalf("IgnoreIPv4() = %s, want %s", got, IPv4) + } + if got := IgnoreIPv6(); got != IPv6 { + t.Fatalf("IgnoreIPv6() = %s, want %s", got, IPv6) + } +} |
